#!/usr/bin/env python3 """跨平台工具函数:设备检测和路径处理""" import os import sys import platform from pathlib import Path from typing import Optional, Union # 尝试导入torch,但不是强制的 try: import torch TORCH_AVAILABLE = True except ImportError: TORCH_AVAILABLE = False torch = None def get_platform_info() -> dict: """获取平台信息""" return { "system": platform.system(), "release": platform.release(), "version": platform.version(), "machine": platform.machine(), "processor": platform.processor(), "python_version": platform.python_version(), "python_executable": sys.executable, "current_dir": os.getcwd(), } def is_windows() -> bool: """检查是否在Windows上运行""" return platform.system().lower() == "windows" def is_linux() -> bool: """检查是否在Linux上运行""" return platform.system().lower() == "linux" def is_macos() -> bool: """检查是否在macOS上运行""" return platform.system().lower() == "darwin" def resolve_device(device: str = "auto", verbose: bool = True) -> str: """ 解析设备字符串,自动检测最佳设备 Args: device: "auto", "cpu", "cuda", or "mps" (macOS) verbose: 是否打印信息 Returns: 设备字符串: "cpu", "cuda", or "mps" """ if not TORCH_AVAILABLE: if verbose: print("警告: PyTorch未安装,只能使用CPU") return "cpu" device = device.lower().strip() if device == "cpu": if verbose: print("设备: CPU (用户指定)") return "cpu" if device == "cuda": if torch.cuda.is_available(): if verbose: gpu_count = torch.cuda.device_count() gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "未知" print(f"设备: CUDA (GPU: {gpu_name}, 数量: {gpu_count})") return "cuda" else: if verbose: print("警告: 指定了CUDA但不可用,自动回退到CPU") print("提示: 检查PyTorch是否安装了GPU版本") return "cpu" if device == "mps" and is_macos(): # macOS的Metal Performance Shaders if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): if verbose: print("设备: MPS (macOS Metal)") return "mps" else: if verbose: print("警告: MPS不可用,回退到CPU") return "cpu" # 自动检测 if torch.cuda.is_available(): if verbose: gpu_count = torch.cuda.device_count() gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "未知" print(f"设备: CUDA (自动检测到GPU: {gpu_name})") return "cuda" if is_macos() and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): if verbose: print("设备: MPS (macOS Metal, 自动检测)") return "mps" if verbose: print("设备: CPU (自动检测,无GPU可用)") return "cpu" def safe_path(path: Union[str, Path]) -> str: """ 安全地处理路径,确保跨平台兼容性 Args: path: 路径字符串或Path对象 Returns: 字符串形式的路径 """ if isinstance(path, Path): return str(path) # 如果是字符串,确保使用正确的路径分隔符 path_str = str(path) # 替换可能存在的错误分隔符 if is_windows(): # Windows上,确保使用反斜杠 path_str = path_str.replace('/', '\\') else: # Linux/macOS上,确保使用正斜杠 path_str = path_str.replace('\\', '/') return path_str def ensure_dir(path: Union[str, Path]) -> Path: """ 确保目录存在,如果不存在则创建 Args: path: 目录路径 Returns: Path对象 """ path_obj = Path(path) if isinstance(path, str) else path path_obj.mkdir(parents=True, exist_ok=True) return path_obj def get_relative_path(base: Union[str, Path], target: Union[str, Path]) -> Path: """ 获取相对于基路径的相对路径 Args: base: 基路径 target: 目标路径 Returns: 相对路径的Path对象 """ base_path = Path(base) if isinstance(base, str) else base target_path = Path(target) if isinstance(target, str) else target # 如果目标路径是绝对路径,直接返回 if target_path.is_absolute(): return target_path # 否则相对于基路径 return (base_path / target_path).resolve() def resolve_path(base: Union[str, Path], target: Union[str, Path]) -> Path: """ Resolve target path against base if target is relative. Args: base: base directory target: target path (absolute or relative) Returns: Absolute Path """ base_path = Path(base) if isinstance(base, str) else base target_path = Path(target) if isinstance(target, str) else target if target_path.is_absolute(): return target_path return (base_path / target_path).resolve() def print_platform_summary(): """打印平台摘要信息""" info = get_platform_info() print("=" * 50) print("平台信息:") print(f" 系统: {info['system']} {info['release']}") print(f" 处理器: {info['processor']}") print(f" Python: {info['python_version']}") print(f" 当前目录: {info['current_dir']}") if TORCH_AVAILABLE: print(f" PyTorch: {torch.__version__}") print(f" CUDA可用: {torch.cuda.is_available()}") if torch.cuda.is_available(): print(f" GPU数量: {torch.cuda.device_count()}") for i in range(torch.cuda.device_count()): print(f" GPU {i}: {torch.cuda.get_device_name(i)}") else: print(" PyTorch: 未安装") print("=" * 50) if __name__ == "__main__": # 测试代码 print_platform_summary() print("\n设备检测测试:") for device in ["auto", "cpu", "cuda", "mps"]: try: result = resolve_device(device, verbose=True) print(f" 输入: '{device}' -> 输出: '{result}'") except Exception as e: print(f" 输入: '{device}' -> 错误: {e}") print("\n路径处理测试:") test_path = "some/path/to/file.txt" print(f" 原始路径: {test_path}") print(f" 安全路径: {safe_path(test_path)}")