Win and linux can run the code

This commit is contained in:
MZ YANG
2026-01-22 17:39:31 +08:00
parent c3f750cd9d
commit f37a8ce179
22 changed files with 32572 additions and 87 deletions

215
example/platform_utils.py Normal file
View File

@@ -0,0 +1,215 @@
#!/usr/bin/env python3
"""跨平台工具函数:设备检测和路径处理"""
import os
import sys
import platform
from pathlib import Path
from typing import Optional, Union
# 尝试导入torch但不是强制的
try:
import torch
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
torch = None
def get_platform_info() -> dict:
"""获取平台信息"""
return {
"system": platform.system(),
"release": platform.release(),
"version": platform.version(),
"machine": platform.machine(),
"processor": platform.processor(),
"python_version": platform.python_version(),
"python_executable": sys.executable,
"current_dir": os.getcwd(),
}
def is_windows() -> bool:
"""检查是否在Windows上运行"""
return platform.system().lower() == "windows"
def is_linux() -> bool:
"""检查是否在Linux上运行"""
return platform.system().lower() == "linux"
def is_macos() -> bool:
"""检查是否在macOS上运行"""
return platform.system().lower() == "darwin"
def resolve_device(device: str = "auto", verbose: bool = True) -> str:
"""
解析设备字符串,自动检测最佳设备
Args:
device: "auto", "cpu", "cuda", or "mps" (macOS)
verbose: 是否打印信息
Returns:
设备字符串: "cpu", "cuda", or "mps"
"""
if not TORCH_AVAILABLE:
if verbose:
print("警告: PyTorch未安装只能使用CPU")
return "cpu"
device = device.lower().strip()
if device == "cpu":
if verbose:
print("设备: CPU (用户指定)")
return "cpu"
if device == "cuda":
if torch.cuda.is_available():
if verbose:
gpu_count = torch.cuda.device_count()
gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "未知"
print(f"设备: CUDA (GPU: {gpu_name}, 数量: {gpu_count})")
return "cuda"
else:
if verbose:
print("警告: 指定了CUDA但不可用自动回退到CPU")
print("提示: 检查PyTorch是否安装了GPU版本")
return "cpu"
if device == "mps" and is_macos():
# macOS的Metal Performance Shaders
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
if verbose:
print("设备: MPS (macOS Metal)")
return "mps"
else:
if verbose:
print("警告: MPS不可用回退到CPU")
return "cpu"
# 自动检测
if torch.cuda.is_available():
if verbose:
gpu_count = torch.cuda.device_count()
gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "未知"
print(f"设备: CUDA (自动检测到GPU: {gpu_name})")
return "cuda"
if is_macos() and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
if verbose:
print("设备: MPS (macOS Metal, 自动检测)")
return "mps"
if verbose:
print("设备: CPU (自动检测无GPU可用)")
return "cpu"
def safe_path(path: Union[str, Path]) -> str:
"""
安全地处理路径,确保跨平台兼容性
Args:
path: 路径字符串或Path对象
Returns:
字符串形式的路径
"""
if isinstance(path, Path):
return str(path)
# 如果是字符串,确保使用正确的路径分隔符
path_str = str(path)
# 替换可能存在的错误分隔符
if is_windows():
# Windows上确保使用反斜杠
path_str = path_str.replace('/', '\\')
else:
# Linux/macOS上确保使用正斜杠
path_str = path_str.replace('\\', '/')
return path_str
def ensure_dir(path: Union[str, Path]) -> Path:
"""
确保目录存在,如果不存在则创建
Args:
path: 目录路径
Returns:
Path对象
"""
path_obj = Path(path) if isinstance(path, str) else path
path_obj.mkdir(parents=True, exist_ok=True)
return path_obj
def get_relative_path(base: Union[str, Path], target: Union[str, Path]) -> Path:
"""
获取相对于基路径的相对路径
Args:
base: 基路径
target: 目标路径
Returns:
相对路径的Path对象
"""
base_path = Path(base) if isinstance(base, str) else base
target_path = Path(target) if isinstance(target, str) else target
# 如果目标路径是绝对路径,直接返回
if target_path.is_absolute():
return target_path
# 否则相对于基路径
return (base_path / target_path).resolve()
def print_platform_summary():
"""打印平台摘要信息"""
info = get_platform_info()
print("=" * 50)
print("平台信息:")
print(f" 系统: {info['system']} {info['release']}")
print(f" 处理器: {info['processor']}")
print(f" Python: {info['python_version']}")
print(f" 当前目录: {info['current_dir']}")
if TORCH_AVAILABLE:
print(f" PyTorch: {torch.__version__}")
print(f" CUDA可用: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f" GPU数量: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
else:
print(" PyTorch: 未安装")
print("=" * 50)
if __name__ == "__main__":
# 测试代码
print_platform_summary()
print("\n设备检测测试:")
for device in ["auto", "cpu", "cuda", "mps"]:
try:
result = resolve_device(device, verbose=True)
print(f" 输入: '{device}' -> 输出: '{result}'")
except Exception as e:
print(f" 输入: '{device}' -> 错误: {e}")
print("\n路径处理测试:")
test_path = "some/path/to/file.txt"
print(f" 原始路径: {test_path}")
print(f" 安全路径: {safe_path(test_path)}")