mask-ddpm/example/debug_cuda_issue.py
#!/usr/bin/env python3
"""Detailed diagnosis of CUDA problems."""
import sys
import os
import subprocess


def check_system_info():
    """Check basic system information."""
    print("=== System information ===")
    import platform
    print(f"Operating system: {platform.system()} {platform.release()}")
    print(f"Python version: {sys.version}")
    print(f"Python executable: {sys.executable}")
    print(f"Current directory: {os.getcwd()}")


def check_nvidia_driver():
    """Check the NVIDIA driver via nvidia-smi."""
    print("\n=== NVIDIA driver check ===")
    try:
        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True,
                                encoding='utf-8', errors='ignore')
        if result.returncode == 0:
            print("✓ nvidia-smi ran successfully")
            # Extract the driver and CUDA versions from the header line
            for line in result.stdout.split('\n'):
                if 'Driver Version' in line:
                    print(f"Driver version: {line.strip()}")
                if 'CUDA Version' in line:
                    print(f"CUDA version: {line.strip()}")
            return True
        else:
            print("✗ nvidia-smi failed")
            print(f"Error: {result.stderr}")
            return False
    except FileNotFoundError:
        print("✗ nvidia-smi not found")
        return False
    except Exception as e:
        print(f"✗ Error while checking the NVIDIA driver: {e}")
        return False


def check_cuda_toolkit():
    """Check the CUDA toolkit via nvcc."""
    print("\n=== CUDA toolkit check ===")
    try:
        result = subprocess.run(['nvcc', '--version'], capture_output=True, text=True,
                                encoding='utf-8', errors='ignore')
        if result.returncode == 0:
            print("✓ nvcc ran successfully")
            for line in result.stdout.split('\n'):
                if 'release' in line.lower():
                    print(f"CUDA compiler: {line.strip()}")
            return True
        else:
            print("✗ nvcc failed")
            return False
    except FileNotFoundError:
        print("✗ nvcc not found")
        return False
    except Exception as e:
        print(f"✗ Error while checking the CUDA toolkit: {e}")
        return False


def check_pytorch_detailed():
    """Check the PyTorch installation in detail."""
    print("\n=== Detailed PyTorch check ===")
    try:
        import torch
        print(f"PyTorch version: {torch.__version__}")
        print(f"PyTorch location: {torch.__file__}")
        # A '+cpu' suffix marks a CPU-only build, '+cuXXX' a GPU build
        if '+cpu' in torch.__version__:
            print("⚠ PyTorch version contains '+cpu'; this is the CPU-only build")
        elif '+cu' in torch.__version__.lower():
            cuda_version = torch.__version__.split('+cu')[-1].split('+')[0]
            print(f"✓ PyTorch is a GPU build (CUDA {cuda_version})")
        else:
            print("⚠ PyTorch version string does not indicate a CPU or GPU build")
        # Check CUDA availability
        print("\nCUDA availability check:")
        print(f"  torch.cuda.is_available(): {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"  CUDA version: {torch.version.cuda}")
            print(f"  GPU count: {torch.cuda.device_count()}")
            for i in range(torch.cuda.device_count()):
                print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
                print(f"    Memory: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.2f} GB")
            # Run a small matrix multiplication on the GPU as a smoke test
            try:
                x = torch.randn(3, 3).cuda()
                y = torch.randn(3, 3).cuda()
                z = torch.matmul(x, y)
                print("  GPU compute test: ✓ passed")
            except Exception as e:
                print(f"  GPU compute test: ✗ failed - {e}")
        else:
            print("  CUDA is not available; possible reasons:")
            print("  1. A CPU-only build of PyTorch is installed")
            print("  2. CUDA driver mismatch")
            print("  3. Environment problems")
        return torch.cuda.is_available()
    except ImportError:
        print("✗ PyTorch is not installed")
        return False
    except Exception as e:
        print(f"✗ Error while checking PyTorch: {e}")
        return False


def check_conda_environment():
    """Check the active conda environment."""
    print("\n=== Conda environment check ===")
    try:
        result = subprocess.run(['conda', 'info', '--envs'], capture_output=True, text=True,
                                encoding='utf-8', errors='ignore')
        if result.returncode == 0:
            print("Current conda environment:")
            # The active environment is marked with '*' in `conda info --envs`
            for line in result.stdout.split('\n'):
                if '*' in line:
                    print(f"  {line.strip()}")
        else:
            print("Could not retrieve conda environment information")
    except Exception as e:
        print(f"Error while checking the conda environment: {e}")


def check_torch_installation():
    """Check installed torch-related packages."""
    print("\n=== Torch installation details ===")
    try:
        # pkg_resources is deprecated in recent setuptools; it still works here
        # for a quick version/location lookup
        import pkg_resources
        packages = ['torch', 'torchvision', 'torchaudio']
        for pkg in packages:
            try:
                dist = pkg_resources.get_distribution(pkg)
                print(f"{pkg}: {dist.version} ({dist.location})")
            except pkg_resources.DistributionNotFound:
                print(f"{pkg}: not installed")
    except Exception as e:
        print(f"Error while checking package details: {e}")


def check_environment_variables():
    """Check CUDA-related environment variables."""
    print("\n=== Environment variable check ===")
    cuda_vars = ['CUDA_HOME', 'CUDA_PATH', 'PATH']
    for var in cuda_vars:
        value = os.environ.get(var, 'not set')
        if var == 'PATH':
            # Only report whether PATH contains a CUDA directory instead of dumping it
            if 'cuda' in value.lower():
                print(f"{var}: contains a CUDA path")
            else:
                print(f"{var}: no CUDA path found")
        else:
            print(f"{var}: {value}")


def main():
    print("Detailed CUDA diagnostics tool")
    print("=" * 60)
    check_system_info()
    check_nvidia_driver()
    check_cuda_toolkit()
    check_conda_environment()
    check_environment_variables()
    check_torch_installation()
    cuda_available = check_pytorch_detailed()
    print("\n" + "=" * 60)
    print("=== Diagnosis and suggested fixes ===")
    if not cuda_available:
        print("\nProblem: PyTorch cannot detect CUDA")
        print("\nPossible causes and fixes:")
        print("1. A CPU-only build of PyTorch is installed")
        print("   Fix: reinstall the GPU build of PyTorch")
        print("   Command: pip uninstall torch torchvision torchaudio")
        print("            pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121")
        print()
        print("2. CUDA version mismatch")
        print("   Your CUDA version: 12.3")
        print("   Install a PyTorch build for a compatible CUDA release")
        print("   (cu121 wheels run on CUDA 12.x drivers)")
        print("   Command: pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121")
        print()
        print("3. Environment variable problems")
        print("   Make sure CUDA_HOME or CUDA_PATH is set correctly")
        print()
        print("4. Driver problems")
        print("   Update the NVIDIA driver to the latest version")
    else:
        print("\n✓ CUDA is available; the GPU can be used")
        print("Run: python run_pipeline.py --device cuda")


if __name__ == "__main__":
    main()
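
# Usage note: run this file directly; the relative path below follows the file
# location shown above and may differ in your checkout (an assumption, not a
# requirement of the script):
#
#     python mask-ddpm/example/debug_cuda_issue.py
#
# If the summary reports that CUDA is available, the pipeline can then be
# started with the command the script itself suggests:
#
#     python run_pipeline.py --device cuda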