Files
mask-ddpm/example/check_gpu_setup.py
2026-01-22 17:39:31 +08:00

132 lines
4.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""检查GPU设置和PyTorch安装"""
import sys
import subprocess
def check_nvidia_gpu(timeout=30):
    """Report whether ``nvidia-smi`` runs successfully on this machine.

    Runs ``nvidia-smi`` as a subprocess, prints a short human-readable report
    and returns the verdict. A zero exit status is treated as "GPU present";
    in that case the first 500 characters of the tool's output are echoed.

    Args:
        timeout: Maximum number of seconds to wait for ``nvidia-smi``.
            Pass ``None`` to wait indefinitely.

    Returns:
        True if ``nvidia-smi`` exited with status 0; False if it exited with
        a non-zero status, is not installed, timed out, or failed to run for
        any other reason.
    """
    print("=== 检查NVIDIA GPU ===")
    try:
        # The timeout keeps this diagnostic from hanging forever when the
        # driver is wedged and nvidia-smi never returns.
        result = subprocess.run(
            ['nvidia-smi'], capture_output=True, text=True, timeout=timeout
        )
    except FileNotFoundError:
        print("✗ nvidia-smi未找到可能没有安装NVIDIA驱动")
        return False
    except subprocess.TimeoutExpired:
        print(f"✗ nvidia-smi在{timeout}秒内没有响应")
        return False
    except Exception as e:
        # Best-effort probe: report any other launch failure instead of crashing.
        print(f"✗ 检查GPU时出错: {e}")
        return False

    if result.returncode != 0:
        print("✗ nvidia-smi命令失败")
        print("错误:", result.stderr)
        return False

    print("✓ 找到NVIDIA GPU")
    print("输出:")
    print(result.stdout[:500])  # only show the first 500 characters
    return True
def check_pytorch_installation():
    """Inspect the installed PyTorch and classify its CUDA support.

    Prints the PyTorch version and CUDA availability, plus per-device details
    when CUDA is usable.

    Returns:
        A ``(ok, kind)`` tuple where ``kind`` is one of:

        * ``"gpu"`` -- PyTorch imported and CUDA is available.
        * ``"cpu"`` -- CUDA unavailable and the version string contains
          ``+cpu``.
        * ``"cuda_but_not_working"`` -- CUDA unavailable but the version
          string has no ``+cpu`` marker.
        * ``"not_installed"`` -- importing ``torch`` raised ``ImportError``.
        * ``"error"`` -- any other exception was raised during the check.

        ``ok`` is True for the first three kinds and False for the last two.
    """
    print("\n=== 检查PyTorch安装 ===")
    try:
        import torch

        print(f"PyTorch版本: {torch.__version__}")
        print(f"CUDA可用: {torch.cuda.is_available()}")

        # Handle the "no CUDA" outcomes first so the GPU report reads straight through.
        if not torch.cuda.is_available():
            print("警告: 安装的是CPU版本的PyTorch")
            # Without a "+cpu" marker this may be a GPU build whose CUDA
            # runtime/driver is not usable.
            if '+cpu' not in torch.__version__:
                print("可能安装了GPU版本但CUDA驱动有问题")
                return True, "cuda_but_not_working"
            print("确认: 安装的是明确的CPU版本 (包含'+cpu')")
            return True, "cpu"

        print(f"CUDA版本: {torch.version.cuda}")
        print(f"GPU数量: {torch.cuda.device_count()}")
        for device_index in range(torch.cuda.device_count()):
            device_name = torch.cuda.get_device_name(device_index)
            print(f" GPU {device_index}: {device_name}")
        return True, "gpu"
    except ImportError:
        print("✗ PyTorch未安装")
        return False, "not_installed"
    except Exception as e:
        print(f"✗ 检查PyTorch时出错: {e}")
        return False, "error"
def get_installation_commands():
    """Print the detected Python/OS versions and PyTorch install commands.

    Lists ``pip`` commands for the CUDA 12.1, CUDA 11.8 and CPU builds and
    a hint for checking the CUDA version. Output only; returns ``None``.
    """
    import platform

    prerequisites = (
        "1. 首先确保安装了NVIDIA驱动和CUDA工具包",
        "2. 根据你的CUDA版本选择以下命令之一:",
    )
    # (heading, command) pairs, printed in this order with a blank line after each.
    variants = (
        (
            "对于CUDA 12.1:",
            " pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121",
        ),
        (
            "对于CUDA 11.8:",
            " pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118",
        ),
        (
            "对于CPU版本:",
            " pip install torch torchvision torchaudio",
        ),
    )

    major, minor = sys.version_info.major, sys.version_info.minor

    print("\n=== 安装建议 ===")
    print(f"Python版本: {major}.{minor}")
    print(f"操作系统: {platform.system()} {platform.release()}")

    print("\n安装GPU版本的PyTorch:")
    for step in prerequisites:
        print(step)
    print()

    for heading, command in variants:
        print(heading)
        print(command)
        print()

    print("要检查CUDA版本可以运行: nvcc --version")
def check_current_setup():
    """Run the GPU and PyTorch checks, print a summary, and return the findings.

    Calls ``check_nvidia_gpu()`` and ``check_pytorch_installation()``, then
    prints a verdict with the matching ``run_pipeline.py`` command. Installation
    hints from ``get_installation_commands()`` are printed when PyTorch is
    missing, CPU-only on a GPU machine, or unable to use CUDA.

    Returns:
        A ``(has_gpu, pytorch_ok, pytorch_type)`` tuple: the result of
        ``check_nvidia_gpu()`` followed by the two values returned by
        ``check_pytorch_installation()``.
    """
    print("=== 当前设置检查 ===")
    has_gpu = check_nvidia_gpu()
    pytorch_ok, pytorch_type = check_pytorch_installation()

    print("\n=== 总结 ===")
    if pytorch_ok and pytorch_type == "gpu":
        # PyTorch reaching a CUDA device is conclusive. The nvidia-smi probe
        # can fail on a working machine (e.g. the tool is not on PATH), and
        # that must not downgrade the verdict to "no GPU, use CPU".
        if not has_gpu:
            print("⚠ nvidia-smi不可用但PyTorch可以正常访问CUDA设备")
        print("✓ 完美你有GPU并且PyTorch GPU版本已正确安装")
        print("可以运行: python run_pipeline.py --device cuda")
    elif has_gpu and pytorch_ok and pytorch_type == "cpu":
        print("⚠ 你有GPU但安装了CPU版本的PyTorch")
        print("建议安装GPU版本的PyTorch以获得更好的性能")
        get_installation_commands()
        print("\n暂时可以运行: python run_pipeline.py --device cpu")
    elif has_gpu and pytorch_ok and pytorch_type == "cuda_but_not_working":
        print("⚠ 可能有GPU和GPU版本的PyTorch但CUDA不可用")
        print("检查CUDA驱动和PyTorch版本是否匹配")
        get_installation_commands()
    elif not has_gpu and pytorch_ok:
        print(" 没有GPU使用CPU版本的PyTorch")
        print("运行: python run_pipeline.py --device cpu")
    elif not pytorch_ok:
        print("✗ PyTorch未正确安装")
        get_installation_commands()
    return has_gpu, pytorch_ok, pytorch_type
def main():
    """Entry point: run all checks and print next-step recommendations.

    The advice depends on the tuple returned by ``check_current_setup()``:

    * PyTorch missing or broken -> install it first.
    * PyTorch can use CUDA -> run with ``--device cuda`` (or ``auto``).
    * GPU detected but CPU-only PyTorch -> suggest the GPU build.
    * No GPU detected -> run on CPU.
    * Otherwise (GPU detected, CUDA not usable from PyTorch) -> try
      ``--device cuda`` / ``auto``.
    """
    print("GPU和PyTorch设置检查工具")
    print("=" * 50)
    has_gpu, pytorch_ok, pytorch_type = check_current_setup()

    print("\n=== 下一步建议 ===")
    if not pytorch_ok:
        # Without a working PyTorch no --device value can help, so do not
        # recommend one.
        print("1. 先按照上面的安装建议安装PyTorch")
        print("2. 安装完成后重新运行此脚本进行检查")
    elif pytorch_type == "gpu":
        # CUDA confirmed by PyTorch itself: recommend it even if the
        # nvidia-smi probe reported no GPU.
        print("1. 可以尝试运行: python run_pipeline.py --device cuda")
        print("2. 或者: python run_pipeline.py --device auto")
    elif has_gpu and pytorch_type == "cpu":
        print("1. 考虑安装GPU版本的PyTorch")
        print("2. 或者继续使用CPU: python run_pipeline.py --device cpu")
    elif not has_gpu:
        print("1. 只能使用CPU运行")
        print("2. 运行: python run_pipeline.py --device cpu")
    else:
        print("1. 可以尝试运行: python run_pipeline.py --device cuda")
        print("2. 或者: python run_pipeline.py --device auto")

    print("\n注意: 代码已修改,当指定--device cuda但CUDA不可用时会自动回退到CPU")
# Run the checks only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()