Files
mask-ddpm/example/platform_utils.py
2026-01-22 20:49:24 +08:00

234 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""跨平台工具函数:设备检测和路径处理"""
import os
import sys
import platform
from pathlib import Path
from typing import Optional, Union
# 尝试导入torch但不是强制的
try:
import torch
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
torch = None
def get_platform_info() -> dict:
"""获取平台信息"""
return {
"system": platform.system(),
"release": platform.release(),
"version": platform.version(),
"machine": platform.machine(),
"processor": platform.processor(),
"python_version": platform.python_version(),
"python_executable": sys.executable,
"current_dir": os.getcwd(),
}
def is_windows() -> bool:
"""检查是否在Windows上运行"""
return platform.system().lower() == "windows"
def is_linux() -> bool:
"""检查是否在Linux上运行"""
return platform.system().lower() == "linux"
def is_macos() -> bool:
"""检查是否在macOS上运行"""
return platform.system().lower() == "darwin"
def resolve_device(device: str = "auto", verbose: bool = True) -> str:
"""
解析设备字符串,自动检测最佳设备
Args:
device: "auto", "cpu", "cuda", or "mps" (macOS)
verbose: 是否打印信息
Returns:
设备字符串: "cpu", "cuda", or "mps"
"""
if not TORCH_AVAILABLE:
if verbose:
print("警告: PyTorch未安装只能使用CPU")
return "cpu"
device = device.lower().strip()
if device == "cpu":
if verbose:
print("设备: CPU (用户指定)")
return "cpu"
if device == "cuda":
if torch.cuda.is_available():
if verbose:
gpu_count = torch.cuda.device_count()
gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "未知"
print(f"设备: CUDA (GPU: {gpu_name}, 数量: {gpu_count})")
return "cuda"
else:
if verbose:
print("警告: 指定了CUDA但不可用自动回退到CPU")
print("提示: 检查PyTorch是否安装了GPU版本")
return "cpu"
if device == "mps" and is_macos():
# macOS的Metal Performance Shaders
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
if verbose:
print("设备: MPS (macOS Metal)")
return "mps"
else:
if verbose:
print("警告: MPS不可用回退到CPU")
return "cpu"
# 自动检测
if torch.cuda.is_available():
if verbose:
gpu_count = torch.cuda.device_count()
gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "未知"
print(f"设备: CUDA (自动检测到GPU: {gpu_name})")
return "cuda"
if is_macos() and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
if verbose:
print("设备: MPS (macOS Metal, 自动检测)")
return "mps"
if verbose:
print("设备: CPU (自动检测无GPU可用)")
return "cpu"
def safe_path(path: Union[str, Path]) -> str:
"""
安全地处理路径,确保跨平台兼容性
Args:
path: 路径字符串或Path对象
Returns:
字符串形式的路径
"""
if isinstance(path, Path):
return str(path)
# 如果是字符串,确保使用正确的路径分隔符
path_str = str(path)
# 替换可能存在的错误分隔符
if is_windows():
# Windows上确保使用反斜杠
path_str = path_str.replace('/', '\\')
else:
# Linux/macOS上确保使用正斜杠
path_str = path_str.replace('\\', '/')
return path_str
def ensure_dir(path: Union[str, Path]) -> Path:
"""
确保目录存在,如果不存在则创建
Args:
path: 目录路径
Returns:
Path对象
"""
path_obj = Path(path) if isinstance(path, str) else path
path_obj.mkdir(parents=True, exist_ok=True)
return path_obj
def get_relative_path(base: Union[str, Path], target: Union[str, Path]) -> Path:
"""
获取相对于基路径的相对路径
Args:
base: 基路径
target: 目标路径
Returns:
相对路径的Path对象
"""
base_path = Path(base) if isinstance(base, str) else base
target_path = Path(target) if isinstance(target, str) else target
# 如果目标路径是绝对路径,直接返回
if target_path.is_absolute():
return target_path
# 否则相对于基路径
return (base_path / target_path).resolve()
def resolve_path(base: Union[str, Path], target: Union[str, Path]) -> Path:
"""
Resolve target path against base if target is relative.
Args:
base: base directory
target: target path (absolute or relative)
Returns:
Absolute Path
"""
base_path = Path(base) if isinstance(base, str) else base
target_path = Path(target) if isinstance(target, str) else target
if target_path.is_absolute():
return target_path
return (base_path / target_path).resolve()
def print_platform_summary():
"""打印平台摘要信息"""
info = get_platform_info()
print("=" * 50)
print("平台信息:")
print(f" 系统: {info['system']} {info['release']}")
print(f" 处理器: {info['processor']}")
print(f" Python: {info['python_version']}")
print(f" 当前目录: {info['current_dir']}")
if TORCH_AVAILABLE:
print(f" PyTorch: {torch.__version__}")
print(f" CUDA可用: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f" GPU数量: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
else:
print(" PyTorch: 未安装")
print("=" * 50)
if __name__ == "__main__":
# 测试代码
print_platform_summary()
print("\n设备检测测试:")
for device in ["auto", "cpu", "cuda", "mps"]:
try:
result = resolve_device(device, verbose=True)
print(f" 输入: '{device}' -> 输出: '{result}'")
except Exception as e:
print(f" 输入: '{device}' -> 错误: {e}")
print("\n路径处理测试:")
test_path = "some/path/to/file.txt"
print(f" 原始路径: {test_path}")
print(f" 安全路径: {safe_path(test_path)}")