Windows and Linux can both run the code
This commit is contained in:
132
example/check_gpu_setup.py
Normal file
132
example/check_gpu_setup.py
Normal file
@@ -0,0 +1,132 @@
|
||||
#!/usr/bin/env python3
|
||||
"""检查GPU设置和PyTorch安装"""
|
||||
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
def check_nvidia_gpu(timeout: float = 10.0) -> bool:
    """Report whether ``nvidia-smi`` runs successfully on this machine.

    Runs ``nvidia-smi`` with its output captured and prints a short status
    report to stdout. On success at most the first 500 characters of the
    tool's output are echoed. Exceptions raised while running the command
    are caught, reported on stdout and mapped to ``False``.

    Args:
        timeout: Maximum number of seconds to wait for ``nvidia-smi``.
            Without a limit, a command that never returns would block the
            whole check indefinitely.

    Returns:
        ``True`` if ``nvidia-smi`` exited with status 0; ``False`` if the
        executable was not found, exited with a non-zero status, timed out,
        or could not be run for any other reason.
    """
    print("=== 检查NVIDIA GPU ===")
    try:
        result = subprocess.run(
            ["nvidia-smi"],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        if result.returncode != 0:
            print("✗ nvidia-smi命令失败")
            print("错误:", result.stderr)
            return False

        print("✓ 找到NVIDIA GPU")
        print("输出:")
        print(result.stdout[:500])  # Only show the first 500 characters.
        return True
    except FileNotFoundError:
        # The executable is not on PATH; the driver is probably not installed.
        print("✗ nvidia-smi未找到,可能没有安装NVIDIA驱动")
        return False
    except subprocess.TimeoutExpired:
        # Reported separately so a hung tool is distinguishable from a crash.
        print(f"✗ nvidia-smi在{timeout}秒内没有响应")
        return False
    except Exception as e:
        # Best-effort diagnostic tool: report the problem instead of crashing.
        print(f"✗ 检查GPU时出错: {e}")
        return False
|
||||
|
||||
def check_pytorch_installation():
    """Inspect the installed PyTorch build and classify its CUDA status.

    Imports ``torch`` lazily so the script still runs when PyTorch is not
    installed, prints the version and CUDA availability, and classifies
    the installation.

    Returns:
        A ``(ok, kind)`` tuple. ``ok`` is ``True`` when ``torch`` could be
        imported and inspected. ``kind`` is one of:

        * ``"gpu"`` - ``torch.cuda.is_available()`` is true.
        * ``"cpu"`` - CUDA is unavailable and the build is CPU-only: the
          version string contains ``+cpu`` or ``torch.version.cuda`` is
          ``None``.
        * ``"cuda_but_not_working"`` - the build reports a CUDA version
          but CUDA is not available at runtime.
        * ``"not_installed"`` - importing ``torch`` raised ``ImportError``.
        * ``"error"`` - any other exception occurred during the check.
    """
    print("\n=== 检查PyTorch安装 ===")
    try:
        import torch

        cuda_available = torch.cuda.is_available()
        print(f"PyTorch版本: {torch.__version__}")
        print(f"CUDA可用: {cuda_available}")

        if cuda_available:
            print(f"CUDA版本: {torch.version.cuda}")
            gpu_count = torch.cuda.device_count()
            print(f"GPU数量: {gpu_count}")
            for i in range(gpu_count):
                print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
            return True, "gpu"

        # Not every CPU-only wheel carries a "+cpu" version suffix, so also
        # consult the CUDA version the build was compiled against.
        # NOTE(review): builds for other accelerators may also report no
        # CUDA version and would be classified as "cpu" here - confirm if
        # such builds need separate handling.
        explicit_cpu_tag = "+cpu" in torch.__version__
        built_cuda = getattr(getattr(torch, "version", None), "cuda", None)

        if explicit_cpu_tag or built_cuda is None:
            # Only warn about a CPU build once we know it really is one.
            print("警告: 安装的是CPU版本的PyTorch")
            if explicit_cpu_tag:
                print("确认: 安装的是明确的CPU版本 (包含'+cpu')")
            return True, "cpu"

        print("可能安装了GPU版本但CUDA驱动有问题")
        return True, "cuda_but_not_working"
    except ImportError:
        print("✗ PyTorch未安装")
        return False, "not_installed"
    except Exception as e:
        # Best-effort diagnostic tool: report the problem instead of crashing.
        print(f"✗ 检查PyTorch时出错: {e}")
        return False, "error"
|
||||
|
||||
def get_installation_commands():
    """Print PyTorch installation advice for this interpreter and OS.

    Writes the running Python ``major.minor`` version, the operating system
    name and release, and a fixed list of ``pip install`` commands (CUDA
    12.1, CUDA 11.8 and CPU-only) to stdout. Returns ``None``.
    """
    import platform

    interpreter = "{}.{}".format(*sys.version_info[:2])
    os_label = f"{platform.system()} {platform.release()}"

    # One entry per printed line; an empty string stands for a blank line.
    report = [
        "\n=== 安装建议 ===",
        f"Python版本: {interpreter}",
        f"操作系统: {os_label}",
        "\n安装GPU版本的PyTorch:",
        "1. 首先确保安装了NVIDIA驱动和CUDA工具包",
        "2. 根据你的CUDA版本选择以下命令之一:",
        "",
        "对于CUDA 12.1:",
        " pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121",
        "",
        "对于CUDA 11.8:",
        " pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118",
        "",
        "对于CPU版本:",
        " pip install torch torchvision torchaudio",
        "",
        "要检查CUDA版本,可以运行: nvcc --version",
    ]
    for entry in report:
        print(entry)
|
||||
|
||||
def check_current_setup():
    """Run the GPU and PyTorch checks and print a combined summary.

    Calls ``check_nvidia_gpu()`` and ``check_pytorch_installation()``,
    prints a verdict for the resulting combination, and prints the
    installation advice from ``get_installation_commands()`` when the
    PyTorch install is missing, CPU-only on a GPU machine, or unable to
    use CUDA.

    Returns:
        A ``(has_gpu, pytorch_ok, pytorch_type)`` tuple. ``has_gpu`` is
        ``True`` when ``nvidia-smi`` succeeded or when PyTorch itself
        reports working CUDA; the other two values are passed through from
        ``check_pytorch_installation()``.
    """
    print("=== 当前设置检查 ===")

    has_gpu = check_nvidia_gpu()
    pytorch_ok, pytorch_type = check_pytorch_installation()

    # PyTorch reporting usable CUDA is direct evidence of a working GPU,
    # even when nvidia-smi could not be run (for example, not on PATH).
    if pytorch_ok and pytorch_type == "gpu":
        has_gpu = True

    print("\n=== 总结 ===")
    if not pytorch_ok:
        print("✗ PyTorch未正确安装")
        get_installation_commands()
    elif pytorch_type == "gpu":
        print("✓ 完美!你有GPU并且PyTorch GPU版本已正确安装")
        print("可以运行: python run_pipeline.py --device cuda")
    elif not has_gpu:
        print("ℹ 没有GPU,使用CPU版本的PyTorch")
        print("运行: python run_pipeline.py --device cpu")
    elif pytorch_type == "cpu":
        print("⚠ 你有GPU但安装了CPU版本的PyTorch")
        print("建议安装GPU版本的PyTorch以获得更好的性能")
        get_installation_commands()
        print("\n暂时可以运行: python run_pipeline.py --device cpu")
    elif pytorch_type == "cuda_but_not_working":
        print("⚠ 可能有GPU和GPU版本的PyTorch,但CUDA不可用")
        print("检查CUDA驱动和PyTorch版本是否匹配")
        get_installation_commands()

    return has_gpu, pytorch_ok, pytorch_type
|
||||
|
||||
def main():
    """Entry point: run all checks, then print next-step recommendations.

    The recommendation is chosen from the ``(has_gpu, pytorch_ok,
    pytorch_type)`` tuple returned by ``check_current_setup()``:

    * PyTorch missing or broken: install it first; no run command is
      suggested.
    * PyTorch reports working CUDA: suggest ``--device cuda`` / ``auto``.
    * GPU present but CPU-only PyTorch: suggest installing the GPU build
      or continuing on CPU.
    * No GPU: suggest ``--device cpu``.
    * Otherwise (GPU present, CUDA not usable from PyTorch): suggest
      ``--device cuda`` / ``auto``, as the original fallback branch did.
    """
    print("GPU和PyTorch设置检查工具")
    print("=" * 50)

    has_gpu, pytorch_ok, pytorch_type = check_current_setup()

    print("\n=== 下一步建议 ===")
    if not pytorch_ok:
        # Nothing can run without PyTorch; do not advertise a run command.
        print("1. 请先按照上面的安装建议安装PyTorch")
        print("2. 安装完成后重新运行本检查工具")
    elif pytorch_type == "gpu":
        # Working CUDA inside PyTorch is sufficient, whatever nvidia-smi said.
        print("1. 可以尝试运行: python run_pipeline.py --device cuda")
        print("2. 或者: python run_pipeline.py --device auto")
    elif has_gpu and pytorch_type == "cpu":
        print("1. 考虑安装GPU版本的PyTorch")
        print("2. 或者继续使用CPU: python run_pipeline.py --device cpu")
    elif not has_gpu:
        print("1. 只能使用CPU运行")
        print("2. 运行: python run_pipeline.py --device cpu")
    else:
        print("1. 可以尝试运行: python run_pipeline.py --device cuda")
        print("2. 或者: python run_pipeline.py --device auto")

    print("\n注意: 代码已修改,当指定--device cuda但CUDA不可用时,会自动回退到CPU")


# Run the checks only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user