Win and linux can run the code
This commit is contained in:
193
example/debug_cuda_issue.py
Normal file
193
example/debug_cuda_issue.py
Normal file
@@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python3
|
||||
"""详细诊断CUDA问题"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
def check_system_info():
|
||||
"""检查系统信息"""
|
||||
print("=== 系统信息 ===")
|
||||
import platform
|
||||
print(f"操作系统: {platform.system()} {platform.release()}")
|
||||
print(f"Python版本: {sys.version}")
|
||||
print(f"Python路径: {sys.executable}")
|
||||
print(f"当前目录: {os.getcwd()}")
|
||||
|
||||
def check_nvidia_driver():
|
||||
"""检查NVIDIA驱动"""
|
||||
print("\n=== NVIDIA驱动检查 ===")
|
||||
try:
|
||||
result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, encoding='utf-8', errors='ignore')
|
||||
if result.returncode == 0:
|
||||
print("✓ nvidia-smi 命令成功")
|
||||
# 提取驱动版本
|
||||
for line in result.stdout.split('\n'):
|
||||
if 'Driver Version' in line:
|
||||
print(f"驱动版本: {line.strip()}")
|
||||
if 'CUDA Version' in line:
|
||||
print(f"CUDA版本: {line.strip()}")
|
||||
return True
|
||||
else:
|
||||
print("✗ nvidia-smi 命令失败")
|
||||
print(f"错误: {result.stderr}")
|
||||
return False
|
||||
except FileNotFoundError:
|
||||
print("✗ nvidia-smi 未找到")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"✗ 检查NVIDIA驱动时出错: {e}")
|
||||
return False
|
||||
|
||||
def check_cuda_toolkit():
|
||||
"""检查CUDA工具包"""
|
||||
print("\n=== CUDA工具包检查 ===")
|
||||
try:
|
||||
result = subprocess.run(['nvcc', '--version'], capture_output=True, text=True, encoding='utf-8', errors='ignore')
|
||||
if result.returncode == 0:
|
||||
print("✓ nvcc 命令成功")
|
||||
for line in result.stdout.split('\n'):
|
||||
if 'release' in line.lower():
|
||||
print(f"CUDA编译器: {line.strip()}")
|
||||
return True
|
||||
else:
|
||||
print("✗ nvcc 命令失败")
|
||||
return False
|
||||
except FileNotFoundError:
|
||||
print("✗ nvcc 未找到")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"✗ 检查CUDA工具包时出错: {e}")
|
||||
return False
|
||||
|
||||
def check_pytorch_detailed():
|
||||
"""详细检查PyTorch安装"""
|
||||
print("\n=== PyTorch详细检查 ===")
|
||||
try:
|
||||
import torch
|
||||
print(f"PyTorch版本: {torch.__version__}")
|
||||
print(f"PyTorch路径: {torch.__file__}")
|
||||
|
||||
# 检查是否包含+cpu
|
||||
if '+cpu' in torch.__version__:
|
||||
print("⚠ PyTorch版本包含 '+cpu',这是CPU版本")
|
||||
elif '+cu' in torch.__version__.lower():
|
||||
cuda_version = torch.__version__.split('+cu')[-1].split('+')[0]
|
||||
print(f"✓ PyTorch是GPU版本,CUDA: {cuda_version}")
|
||||
else:
|
||||
print("ℹ PyTorch版本信息不明确")
|
||||
|
||||
# 检查CUDA可用性
|
||||
print(f"\nCUDA可用性检查:")
|
||||
print(f" torch.cuda.is_available(): {torch.cuda.is_available()}")
|
||||
|
||||
if torch.cuda.is_available():
|
||||
print(f" CUDA版本: {torch.version.cuda}")
|
||||
print(f" GPU数量: {torch.cuda.device_count()}")
|
||||
for i in range(torch.cuda.device_count()):
|
||||
print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
|
||||
print(f" 内存: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.2f} GB")
|
||||
|
||||
# 测试GPU计算
|
||||
try:
|
||||
x = torch.randn(3, 3).cuda()
|
||||
y = torch.randn(3, 3).cuda()
|
||||
z = torch.matmul(x, y)
|
||||
print(f" GPU计算测试: ✓ 成功")
|
||||
except Exception as e:
|
||||
print(f" GPU计算测试: ✗ 失败 - {e}")
|
||||
else:
|
||||
print(" CUDA不可用,原因可能是:")
|
||||
print(" 1. 安装了CPU版本的PyTorch")
|
||||
print(" 2. CUDA驱动不匹配")
|
||||
print(" 3. 系统环境问题")
|
||||
|
||||
return torch.cuda.is_available()
|
||||
|
||||
except ImportError:
|
||||
print("✗ PyTorch未安装")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"✗ 检查PyTorch时出错: {e}")
|
||||
return False
|
||||
|
||||
def check_conda_environment():
|
||||
"""检查conda环境"""
|
||||
print("\n=== Conda环境检查 ===")
|
||||
try:
|
||||
result = subprocess.run(['conda', 'info', '--envs'], capture_output=True, text=True, encoding='utf-8', errors='ignore')
|
||||
if result.returncode == 0:
|
||||
print("当前conda环境:")
|
||||
for line in result.stdout.split('\n'):
|
||||
if '*' in line:
|
||||
print(f" {line.strip()}")
|
||||
else:
|
||||
print("无法获取conda环境信息")
|
||||
except Exception as e:
|
||||
print(f"检查conda环境时出错: {e}")
|
||||
|
||||
def check_torch_installation():
|
||||
"""检查torch安装详情"""
|
||||
print("\n=== Torch安装详情 ===")
|
||||
try:
|
||||
import pkg_resources
|
||||
packages = ['torch', 'torchvision', 'torchaudio']
|
||||
for pkg in packages:
|
||||
try:
|
||||
dist = pkg_resources.get_distribution(pkg)
|
||||
print(f"{pkg}: {dist.version} ({dist.location})")
|
||||
except pkg_resources.DistributionNotFound:
|
||||
print(f"{pkg}: 未安装")
|
||||
except Exception as e:
|
||||
print(f"检查包详情时出错: {e}")
|
||||
|
||||
def check_environment_variables():
|
||||
"""检查环境变量"""
|
||||
print("\n=== 环境变量检查 ===")
|
||||
cuda_vars = ['CUDA_HOME', 'CUDA_PATH', 'PATH']
|
||||
for var in cuda_vars:
|
||||
value = os.environ.get(var, '未设置')
|
||||
if var == 'PATH' and 'cuda' in value.lower():
|
||||
print(f"{var}: 包含CUDA路径")
|
||||
elif var != 'PATH':
|
||||
print(f"{var}: {value}")
|
||||
|
||||
def main():
|
||||
print("CUDA问题详细诊断工具")
|
||||
print("=" * 60)
|
||||
|
||||
check_system_info()
|
||||
check_nvidia_driver()
|
||||
check_cuda_toolkit()
|
||||
check_conda_environment()
|
||||
check_environment_variables()
|
||||
check_torch_installation()
|
||||
cuda_available = check_pytorch_detailed()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("=== 问题诊断与解决方案 ===")
|
||||
|
||||
if not cuda_available:
|
||||
print("\n问题: PyTorch无法检测到CUDA")
|
||||
print("\n可能的原因和解决方案:")
|
||||
print("1. 安装了CPU版本的PyTorch")
|
||||
print(" 解决方案: 重新安装GPU版本的PyTorch")
|
||||
print(" 命令: pip uninstall torch torchvision torchaudio")
|
||||
print(" pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121")
|
||||
print()
|
||||
print("2. CUDA版本不匹配")
|
||||
print(" 你的CUDA版本: 12.3")
|
||||
print(" 需要安装对应的PyTorch版本")
|
||||
print(" 命令: pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu123")
|
||||
print()
|
||||
print("3. 环境变量问题")
|
||||
print(" 确保CUDA_HOME或CUDA_PATH正确设置")
|
||||
print()
|
||||
print("4. 驱动程序问题")
|
||||
print(" 更新NVIDIA驱动程序到最新版本")
|
||||
else:
|
||||
print("\n✓ CUDA可用,可以正常使用GPU")
|
||||
print("运行命令: python run_pipeline.py --device cuda")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user