215 lines
6.0 KiB
Python
215 lines
6.0 KiB
Python
#!/usr/bin/env python3
|
||
"""跨平台工具函数:设备检测和路径处理"""
|
||
|
||
import os
|
||
import sys
|
||
import platform
|
||
from pathlib import Path
|
||
from typing import Optional, Union
|
||
|
||
# 尝试导入torch,但不是强制的
|
||
try:
|
||
import torch
|
||
TORCH_AVAILABLE = True
|
||
except ImportError:
|
||
TORCH_AVAILABLE = False
|
||
torch = None
|
||
|
||
|
||
def get_platform_info() -> dict:
|
||
"""获取平台信息"""
|
||
return {
|
||
"system": platform.system(),
|
||
"release": platform.release(),
|
||
"version": platform.version(),
|
||
"machine": platform.machine(),
|
||
"processor": platform.processor(),
|
||
"python_version": platform.python_version(),
|
||
"python_executable": sys.executable,
|
||
"current_dir": os.getcwd(),
|
||
}
|
||
|
||
|
||
def is_windows() -> bool:
|
||
"""检查是否在Windows上运行"""
|
||
return platform.system().lower() == "windows"
|
||
|
||
|
||
def is_linux() -> bool:
|
||
"""检查是否在Linux上运行"""
|
||
return platform.system().lower() == "linux"
|
||
|
||
|
||
def is_macos() -> bool:
|
||
"""检查是否在macOS上运行"""
|
||
return platform.system().lower() == "darwin"
|
||
|
||
|
||
def resolve_device(device: str = "auto", verbose: bool = True) -> str:
|
||
"""
|
||
解析设备字符串,自动检测最佳设备
|
||
|
||
Args:
|
||
device: "auto", "cpu", "cuda", or "mps" (macOS)
|
||
verbose: 是否打印信息
|
||
|
||
Returns:
|
||
设备字符串: "cpu", "cuda", or "mps"
|
||
"""
|
||
if not TORCH_AVAILABLE:
|
||
if verbose:
|
||
print("警告: PyTorch未安装,只能使用CPU")
|
||
return "cpu"
|
||
|
||
device = device.lower().strip()
|
||
|
||
if device == "cpu":
|
||
if verbose:
|
||
print("设备: CPU (用户指定)")
|
||
return "cpu"
|
||
|
||
if device == "cuda":
|
||
if torch.cuda.is_available():
|
||
if verbose:
|
||
gpu_count = torch.cuda.device_count()
|
||
gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "未知"
|
||
print(f"设备: CUDA (GPU: {gpu_name}, 数量: {gpu_count})")
|
||
return "cuda"
|
||
else:
|
||
if verbose:
|
||
print("警告: 指定了CUDA但不可用,自动回退到CPU")
|
||
print("提示: 检查PyTorch是否安装了GPU版本")
|
||
return "cpu"
|
||
|
||
if device == "mps" and is_macos():
|
||
# macOS的Metal Performance Shaders
|
||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||
if verbose:
|
||
print("设备: MPS (macOS Metal)")
|
||
return "mps"
|
||
else:
|
||
if verbose:
|
||
print("警告: MPS不可用,回退到CPU")
|
||
return "cpu"
|
||
|
||
# 自动检测
|
||
if torch.cuda.is_available():
|
||
if verbose:
|
||
gpu_count = torch.cuda.device_count()
|
||
gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "未知"
|
||
print(f"设备: CUDA (自动检测到GPU: {gpu_name})")
|
||
return "cuda"
|
||
|
||
if is_macos() and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||
if verbose:
|
||
print("设备: MPS (macOS Metal, 自动检测)")
|
||
return "mps"
|
||
|
||
if verbose:
|
||
print("设备: CPU (自动检测,无GPU可用)")
|
||
return "cpu"
|
||
|
||
|
||
def safe_path(path: Union[str, Path]) -> str:
|
||
"""
|
||
安全地处理路径,确保跨平台兼容性
|
||
|
||
Args:
|
||
path: 路径字符串或Path对象
|
||
|
||
Returns:
|
||
字符串形式的路径
|
||
"""
|
||
if isinstance(path, Path):
|
||
return str(path)
|
||
|
||
# 如果是字符串,确保使用正确的路径分隔符
|
||
path_str = str(path)
|
||
|
||
# 替换可能存在的错误分隔符
|
||
if is_windows():
|
||
# Windows上,确保使用反斜杠
|
||
path_str = path_str.replace('/', '\\')
|
||
else:
|
||
# Linux/macOS上,确保使用正斜杠
|
||
path_str = path_str.replace('\\', '/')
|
||
|
||
return path_str
|
||
|
||
|
||
def ensure_dir(path: Union[str, Path]) -> Path:
|
||
"""
|
||
确保目录存在,如果不存在则创建
|
||
|
||
Args:
|
||
path: 目录路径
|
||
|
||
Returns:
|
||
Path对象
|
||
"""
|
||
path_obj = Path(path) if isinstance(path, str) else path
|
||
path_obj.mkdir(parents=True, exist_ok=True)
|
||
return path_obj
|
||
|
||
|
||
def get_relative_path(base: Union[str, Path], target: Union[str, Path]) -> Path:
|
||
"""
|
||
获取相对于基路径的相对路径
|
||
|
||
Args:
|
||
base: 基路径
|
||
target: 目标路径
|
||
|
||
Returns:
|
||
相对路径的Path对象
|
||
"""
|
||
base_path = Path(base) if isinstance(base, str) else base
|
||
target_path = Path(target) if isinstance(target, str) else target
|
||
|
||
# 如果目标路径是绝对路径,直接返回
|
||
if target_path.is_absolute():
|
||
return target_path
|
||
|
||
# 否则相对于基路径
|
||
return (base_path / target_path).resolve()
|
||
|
||
|
||
def print_platform_summary():
|
||
"""打印平台摘要信息"""
|
||
info = get_platform_info()
|
||
print("=" * 50)
|
||
print("平台信息:")
|
||
print(f" 系统: {info['system']} {info['release']}")
|
||
print(f" 处理器: {info['processor']}")
|
||
print(f" Python: {info['python_version']}")
|
||
print(f" 当前目录: {info['current_dir']}")
|
||
|
||
if TORCH_AVAILABLE:
|
||
print(f" PyTorch: {torch.__version__}")
|
||
print(f" CUDA可用: {torch.cuda.is_available()}")
|
||
if torch.cuda.is_available():
|
||
print(f" GPU数量: {torch.cuda.device_count()}")
|
||
for i in range(torch.cuda.device_count()):
|
||
print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
|
||
else:
|
||
print(" PyTorch: 未安装")
|
||
|
||
print("=" * 50)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
# 测试代码
|
||
print_platform_summary()
|
||
|
||
print("\n设备检测测试:")
|
||
for device in ["auto", "cpu", "cuda", "mps"]:
|
||
try:
|
||
result = resolve_device(device, verbose=True)
|
||
print(f" 输入: '{device}' -> 输出: '{result}'")
|
||
except Exception as e:
|
||
print(f" 输入: '{device}' -> 错误: {e}")
|
||
|
||
print("\n路径处理测试:")
|
||
test_path = "some/path/to/file.txt"
|
||
print(f" 原始路径: {test_path}")
|
||
print(f" 安全路径: {safe_path(test_path)}") |