Files
mask-ddpm/example/find_correct_torch.py
2026-01-22 17:39:31 +08:00

75 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""查找正确的PyTorch安装命令"""
import sys
def check_cuda_compatibility():
"""检查CUDA兼容性"""
print("=== CUDA兼容性检查 ===")
print("你的CUDA版本: 12.3")
print("\nPyTorch官方支持的CUDA版本:")
print("1. CUDA 12.1 - 最新稳定版")
print("2. CUDA 11.8 - 广泛支持")
print("3. CUDA 11.7 - 旧版本支持")
print("\n注意: CUDA 12.3可能还没有官方预编译包")
def get_installation_options():
"""获取安装选项"""
print("\n=== 安装选项 ===")
print("选项1: 使用CUDA 12.1(向后兼容)")
print("大多数CUDA 12.x版本是兼容的")
print("命令: pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121")
print()
print("选项2: 使用conda安装推荐")
print("conda会自动处理CUDA兼容性")
print("命令: conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia")
print()
print("选项3: 从源码编译(高级用户)")
print("如果需要特定CUDA版本")
print()
print("选项4: 使用CPU版本")
print("如果没有GPU或不想处理兼容性问题")
print("命令: pip install torch torchvision torchaudio")
def check_current_torch():
"""检查当前torch安装"""
print("\n=== 当前PyTorch状态 ===")
try:
import torch
print(f"已安装版本: {torch.__version__}")
print(f"CUDA可用: {torch.cuda.is_available()}")
if '+cpu' in torch.__version__:
print("类型: CPU版本")
elif '+cu' in torch.__version__.lower():
print("类型: GPU版本")
cuda_ver = torch.__version__.split('+cu')[-1].split('+')[0]
print(f"编译的CUDA版本: {cuda_ver}")
else:
print("类型: 未知")
except ImportError:
print("PyTorch未安装")
def main():
print("PyTorch安装指南")
print("=" * 50)
check_cuda_compatibility()
check_current_torch()
get_installation_options()
print("\n=== 推荐方案 ===")
print("1. 首先尝试选项1CUDA 12.1")
print("2. 如果不行使用选项2conda安装")
print("3. 或者暂时使用CPU版本运行代码")
print("\n=== 验证安装 ===")
print("安装后运行: python -c \"import torch; print(f'版本: {torch.__version__}'); print(f'CUDA可用: {torch.cuda.is_available()}')\"")
if __name__ == "__main__":
main()