#!/usr/bin/env python3
|
||
"""检查GPU设置和PyTorch安装"""
|
||
|
||
import sys
|
||
import subprocess
|
||
|
||
def check_nvidia_gpu():
    """Check whether an NVIDIA GPU is visible via ``nvidia-smi``.

    Returns:
        bool: True if ``nvidia-smi`` runs and exits successfully,
        False if it fails, is missing, or errors out.
    """
    print("=== 检查NVIDIA GPU ===")
    try:
        # timeout guards against nvidia-smi hanging on a broken driver;
        # TimeoutExpired is caught by the generic handler below.
        result = subprocess.run(
            ['nvidia-smi'], capture_output=True, text=True, timeout=30
        )
        if result.returncode == 0:
            print("✓ 找到NVIDIA GPU")
            print("输出:")
            print(result.stdout[:500])  # show only the first 500 characters
            return True
        else:
            print("✗ nvidia-smi命令失败")
            print("错误:", result.stderr)
            return False
    except FileNotFoundError:
        # nvidia-smi binary not on PATH — no driver installed.
        print("✗ nvidia-smi未找到,可能没有安装NVIDIA驱动")
        return False
    except Exception as e:
        print(f"✗ 检查GPU时出错: {e}")
        return False
def check_pytorch_installation():
    """Report the installed PyTorch build and its CUDA status.

    Returns:
        tuple[bool, str]: (ok, kind) where kind is one of
        "gpu", "cpu", "cuda_but_not_working", "not_installed", "error".
    """
    print("\n=== 检查PyTorch安装 ===")
    try:
        import torch

        print(f"PyTorch版本: {torch.__version__}")
        print(f"CUDA可用: {torch.cuda.is_available()}")

        if not torch.cuda.is_available():
            print("警告: 安装的是CPU版本的PyTorch")
            # Distinguish an explicit CPU-only build from a GPU build
            # whose CUDA runtime is broken.
            if '+cpu' in torch.__version__:
                print("确认: 安装的是明确的CPU版本 (包含'+cpu')")
                return True, "cpu"
            print("可能安装了GPU版本但CUDA驱动有问题")
            return True, "cuda_but_not_working"

        print(f"CUDA版本: {torch.version.cuda}")
        print(f"GPU数量: {torch.cuda.device_count()}")
        for idx in range(torch.cuda.device_count()):
            print(f" GPU {idx}: {torch.cuda.get_device_name(idx)}")
        return True, "gpu"
    except ImportError:
        print("✗ PyTorch未安装")
        return False, "not_installed"
    except Exception as exc:
        print(f"✗ 检查PyTorch时出错: {exc}")
        return False, "error"
def get_installation_commands():
    """Print environment details and pip commands for installing PyTorch."""
    print("\n=== 安装建议 ===")

    import platform
    python_version = f"{sys.version_info.major}.{sys.version_info.minor}"

    print(f"Python版本: {python_version}")
    print(f"操作系统: {platform.system()} {platform.release()}")

    # Install advice, one output line per entry.
    advice = [
        "\n安装GPU版本的PyTorch:",
        "1. 首先确保安装了NVIDIA驱动和CUDA工具包",
        "2. 根据你的CUDA版本选择以下命令之一:",
        "",
        "对于CUDA 12.1:",
        " pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121",
        "",
        "对于CUDA 11.8:",
        " pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118",
        "",
        "对于CPU版本:",
        " pip install torch torchvision torchaudio",
        "",
        "要检查CUDA版本,可以运行: nvcc --version",
    ]
    for line in advice:
        print(line)
||
def check_current_setup():
    """Run the GPU and PyTorch checks, then print an overall summary.

    Returns:
        tuple: (has_gpu, pytorch_ok, pytorch_type) as produced by the
        individual check functions.
    """
    print("=== 当前设置检查 ===")

    has_gpu = check_nvidia_gpu()
    pytorch_ok, pytorch_type = check_pytorch_installation()

    print("\n=== 总结 ===")
    if pytorch_ok and has_gpu:
        # GPU present and PyTorch importable — outcome depends on build type.
        if pytorch_type == "gpu":
            print("✓ 完美!你有GPU并且PyTorch GPU版本已正确安装")
            print("可以运行: python run_pipeline.py --device cuda")
        elif pytorch_type == "cpu":
            print("⚠ 你有GPU但安装了CPU版本的PyTorch")
            print("建议安装GPU版本的PyTorch以获得更好的性能")
            get_installation_commands()
            print("\n暂时可以运行: python run_pipeline.py --device cpu")
        elif pytorch_type == "cuda_but_not_working":
            print("⚠ 可能有GPU和GPU版本的PyTorch,但CUDA不可用")
            print("检查CUDA驱动和PyTorch版本是否匹配")
            get_installation_commands()
    elif pytorch_ok:
        print("ℹ 没有GPU,使用CPU版本的PyTorch")
        print("运行: python run_pipeline.py --device cpu")
    else:
        print("✗ PyTorch未正确安装")
        get_installation_commands()

    return has_gpu, pytorch_ok, pytorch_type
||
def main():
    """Entry point: run all setup checks and print next-step advice."""
    print("GPU和PyTorch设置检查工具")
    print("=" * 50)

    has_gpu, torch_ok, torch_kind = check_current_setup()

    print("\n=== 下一步建议 ===")
    if not has_gpu:
        print("1. 只能使用CPU运行")
        print("2. 运行: python run_pipeline.py --device cpu")
    elif torch_kind == "cpu":
        print("1. 考虑安装GPU版本的PyTorch")
        print("2. 或者继续使用CPU: python run_pipeline.py --device cpu")
    else:
        print("1. 可以尝试运行: python run_pipeline.py --device cuda")
        print("2. 或者: python run_pipeline.py --device auto")

    print("\n注意: 代码已修改,当指定--device cuda但CUDA不可用时,会自动回退到CPU")
||
# Run the checks when executed as a script.
if __name__ == "__main__":
    main()