Win and linux can run the code
This commit is contained in:
215
example/platform_utils.py
Normal file
215
example/platform_utils.py
Normal file
@@ -0,0 +1,215 @@
|
||||
#!/usr/bin/env python3
|
||||
"""跨平台工具函数:设备检测和路径处理"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
# 尝试导入torch,但不是强制的
|
||||
try:
|
||||
import torch
|
||||
TORCH_AVAILABLE = True
|
||||
except ImportError:
|
||||
TORCH_AVAILABLE = False
|
||||
torch = None
|
||||
|
||||
|
||||
def get_platform_info() -> dict:
|
||||
"""获取平台信息"""
|
||||
return {
|
||||
"system": platform.system(),
|
||||
"release": platform.release(),
|
||||
"version": platform.version(),
|
||||
"machine": platform.machine(),
|
||||
"processor": platform.processor(),
|
||||
"python_version": platform.python_version(),
|
||||
"python_executable": sys.executable,
|
||||
"current_dir": os.getcwd(),
|
||||
}
|
||||
|
||||
|
||||
def is_windows() -> bool:
|
||||
"""检查是否在Windows上运行"""
|
||||
return platform.system().lower() == "windows"
|
||||
|
||||
|
||||
def is_linux() -> bool:
|
||||
"""检查是否在Linux上运行"""
|
||||
return platform.system().lower() == "linux"
|
||||
|
||||
|
||||
def is_macos() -> bool:
|
||||
"""检查是否在macOS上运行"""
|
||||
return platform.system().lower() == "darwin"
|
||||
|
||||
|
||||
def resolve_device(device: str = "auto", verbose: bool = True) -> str:
|
||||
"""
|
||||
解析设备字符串,自动检测最佳设备
|
||||
|
||||
Args:
|
||||
device: "auto", "cpu", "cuda", or "mps" (macOS)
|
||||
verbose: 是否打印信息
|
||||
|
||||
Returns:
|
||||
设备字符串: "cpu", "cuda", or "mps"
|
||||
"""
|
||||
if not TORCH_AVAILABLE:
|
||||
if verbose:
|
||||
print("警告: PyTorch未安装,只能使用CPU")
|
||||
return "cpu"
|
||||
|
||||
device = device.lower().strip()
|
||||
|
||||
if device == "cpu":
|
||||
if verbose:
|
||||
print("设备: CPU (用户指定)")
|
||||
return "cpu"
|
||||
|
||||
if device == "cuda":
|
||||
if torch.cuda.is_available():
|
||||
if verbose:
|
||||
gpu_count = torch.cuda.device_count()
|
||||
gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "未知"
|
||||
print(f"设备: CUDA (GPU: {gpu_name}, 数量: {gpu_count})")
|
||||
return "cuda"
|
||||
else:
|
||||
if verbose:
|
||||
print("警告: 指定了CUDA但不可用,自动回退到CPU")
|
||||
print("提示: 检查PyTorch是否安装了GPU版本")
|
||||
return "cpu"
|
||||
|
||||
if device == "mps" and is_macos():
|
||||
# macOS的Metal Performance Shaders
|
||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
if verbose:
|
||||
print("设备: MPS (macOS Metal)")
|
||||
return "mps"
|
||||
else:
|
||||
if verbose:
|
||||
print("警告: MPS不可用,回退到CPU")
|
||||
return "cpu"
|
||||
|
||||
# 自动检测
|
||||
if torch.cuda.is_available():
|
||||
if verbose:
|
||||
gpu_count = torch.cuda.device_count()
|
||||
gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "未知"
|
||||
print(f"设备: CUDA (自动检测到GPU: {gpu_name})")
|
||||
return "cuda"
|
||||
|
||||
if is_macos() and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
if verbose:
|
||||
print("设备: MPS (macOS Metal, 自动检测)")
|
||||
return "mps"
|
||||
|
||||
if verbose:
|
||||
print("设备: CPU (自动检测,无GPU可用)")
|
||||
return "cpu"
|
||||
|
||||
|
||||
def safe_path(path: Union[str, Path]) -> str:
|
||||
"""
|
||||
安全地处理路径,确保跨平台兼容性
|
||||
|
||||
Args:
|
||||
path: 路径字符串或Path对象
|
||||
|
||||
Returns:
|
||||
字符串形式的路径
|
||||
"""
|
||||
if isinstance(path, Path):
|
||||
return str(path)
|
||||
|
||||
# 如果是字符串,确保使用正确的路径分隔符
|
||||
path_str = str(path)
|
||||
|
||||
# 替换可能存在的错误分隔符
|
||||
if is_windows():
|
||||
# Windows上,确保使用反斜杠
|
||||
path_str = path_str.replace('/', '\\')
|
||||
else:
|
||||
# Linux/macOS上,确保使用正斜杠
|
||||
path_str = path_str.replace('\\', '/')
|
||||
|
||||
return path_str
|
||||
|
||||
|
||||
def ensure_dir(path: Union[str, Path]) -> Path:
|
||||
"""
|
||||
确保目录存在,如果不存在则创建
|
||||
|
||||
Args:
|
||||
path: 目录路径
|
||||
|
||||
Returns:
|
||||
Path对象
|
||||
"""
|
||||
path_obj = Path(path) if isinstance(path, str) else path
|
||||
path_obj.mkdir(parents=True, exist_ok=True)
|
||||
return path_obj
|
||||
|
||||
|
||||
def get_relative_path(base: Union[str, Path], target: Union[str, Path]) -> Path:
|
||||
"""
|
||||
获取相对于基路径的相对路径
|
||||
|
||||
Args:
|
||||
base: 基路径
|
||||
target: 目标路径
|
||||
|
||||
Returns:
|
||||
相对路径的Path对象
|
||||
"""
|
||||
base_path = Path(base) if isinstance(base, str) else base
|
||||
target_path = Path(target) if isinstance(target, str) else target
|
||||
|
||||
# 如果目标路径是绝对路径,直接返回
|
||||
if target_path.is_absolute():
|
||||
return target_path
|
||||
|
||||
# 否则相对于基路径
|
||||
return (base_path / target_path).resolve()
|
||||
|
||||
|
||||
def print_platform_summary():
|
||||
"""打印平台摘要信息"""
|
||||
info = get_platform_info()
|
||||
print("=" * 50)
|
||||
print("平台信息:")
|
||||
print(f" 系统: {info['system']} {info['release']}")
|
||||
print(f" 处理器: {info['processor']}")
|
||||
print(f" Python: {info['python_version']}")
|
||||
print(f" 当前目录: {info['current_dir']}")
|
||||
|
||||
if TORCH_AVAILABLE:
|
||||
print(f" PyTorch: {torch.__version__}")
|
||||
print(f" CUDA可用: {torch.cuda.is_available()}")
|
||||
if torch.cuda.is_available():
|
||||
print(f" GPU数量: {torch.cuda.device_count()}")
|
||||
for i in range(torch.cuda.device_count()):
|
||||
print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
|
||||
else:
|
||||
print(" PyTorch: 未安装")
|
||||
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试代码
|
||||
print_platform_summary()
|
||||
|
||||
print("\n设备检测测试:")
|
||||
for device in ["auto", "cpu", "cuda", "mps"]:
|
||||
try:
|
||||
result = resolve_device(device, verbose=True)
|
||||
print(f" 输入: '{device}' -> 输出: '{result}'")
|
||||
except Exception as e:
|
||||
print(f" 输入: '{device}' -> 错误: {e}")
|
||||
|
||||
print("\n路径处理测试:")
|
||||
test_path = "some/path/to/file.txt"
|
||||
print(f" 原始路径: {test_path}")
|
||||
print(f" 安全路径: {safe_path(test_path)}")
|
||||
Reference in New Issue
Block a user