Files
mask-ddpm/example/test_gpu.py
2026-01-22 17:39:31 +08:00

128 lines
4.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""测试GPU可用性和配置"""
import torch
import sys
def test_gpu():
    """Print PyTorch/CUDA diagnostics and smoke-test a matmul on each visible GPU.

    Returns:
        bool: True when CUDA is usable, False when only the CPU is available.
    """
    print("=== GPU测试 ===")
    print(f"PyTorch版本: {torch.__version__}")
    cuda_ok = torch.cuda.is_available()
    print(f"CUDA可用: {cuda_ok}")
    if not cuda_ok:
        print("警告: CUDA不可用将使用CPU运行")
        return False
    print(f"CUDA版本: {torch.version.cuda}")
    gpu_count = torch.cuda.device_count()
    print(f"GPU数量: {gpu_count}")
    for idx in range(gpu_count):
        props = torch.cuda.get_device_properties(idx)
        print(f"\nGPU {idx}:")
        print(f" 名称: {torch.cuda.get_device_name(idx)}")
        print(f" 内存总量: {props.total_memory / 1024**3:.2f} GB")
        # Run a tiny matmul so driver/runtime mismatches surface here
        # instead of mid-training.
        try:
            lhs = torch.randn(3, 3).cuda(idx)
            rhs = torch.randn(3, 3).cuda(idx)
            _ = torch.matmul(lhs, rhs)
            print(f" GPU计算测试: 成功")
            print(f" 当前内存使用: {torch.cuda.memory_allocated(idx) / 1024**2:.2f} MB")
            print(f" 最大内存使用: {torch.cuda.max_memory_allocated(idx) / 1024**2:.2f} MB")
        except Exception as err:
            print(f" GPU计算测试: 失败 - {err}")
    return True
def test_device_resolution():
    """Exercise the device-string resolver over 'auto'/'cpu'/'cuda' and print each outcome."""
    print("\n=== 设备解析测试 ===")

    def resolve_device(mode: str) -> str:
        # Mirrors the pipeline's selection: explicit cpu/cuda requests are
        # honored (cuda hard-fails when unavailable); anything else auto-detects.
        lowered = mode.lower()
        if lowered == "cpu":
            return "cpu"
        if lowered == "cuda":
            if not torch.cuda.is_available():
                raise SystemExit("device set to cuda but CUDA is not available")
            return "cuda"
        return "cuda" if torch.cuda.is_available() else "cpu"

    cases = [
        ("auto", "应该自动选择GPU如果可用"),
        ("cpu", "应该强制使用CPU"),
        ("cuda", "应该使用GPU如果可用"),
    ]
    for mode, note in cases:
        try:
            resolved = resolve_device(mode)
        except SystemExit as err:
            print(f"模式 '{mode}': 错误 - {err}")
        except Exception as err:
            print(f"模式 '{mode}': 异常 - {err}")
        else:
            print(f"模式 '{mode}': {resolved} - {note}")
def test_training_components():
    """Smoke-test the project's diffusion model: import, device transfer, forward pass.

    Failures are reported to stdout rather than raised, so the overall
    diagnostic run continues even when the project module is missing or
    the forward pass errors out.
    """
    print("\n=== 训练组件测试 ===")
    try:
        # Importing cosine_beta_schedule alongside the model verifies the
        # module exposes both names, even though only the model is used here.
        from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
        print("✓ hybrid_diffusion 模块导入成功")
        # Small configuration: 10 continuous dims, three categorical fields.
        model = HybridDiffusionModel(cont_dim=10, disc_vocab_sizes=[5, 3, 2])
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            model = model.cuda()
            print("✓ 模型可以移动到GPU")
        else:
            print("✓ 模型在CPU上")
        # Tiny batch for the forward-pass check.
        batch_size = 2
        seq_len = 16
        x_cont = torch.randn(batch_size, seq_len, 10)
        x_disc = torch.randint(0, 5, (batch_size, seq_len, 3))
        t = torch.randint(0, 100, (batch_size,))
        if use_cuda:
            # Only the inputs need transferring; the model was already moved
            # above (fix: the previous version redundantly re-transferred it).
            x_cont = x_cont.cuda()
            x_disc = x_disc.cuda()
            t = t.cuda()
        eps_pred, logits = model(x_cont, x_disc, t)
        print(f"✓ 前向传播成功")
        print(f" 连续输出形状: {eps_pred.shape}")
        print(f" 离散输出数量: {len(logits)}")
    except ImportError as e:
        print(f"✗ 模块导入失败: {e}")
    except Exception as e:
        # Deliberate best-effort: this is a diagnostic script, so any failure
        # is printed and the run continues.
        print(f"✗ 测试失败: {e}")
def main():
    """Run all diagnostics, print usage suggestions, and report GPU availability.

    Returns:
        bool: True when a usable GPU was detected.
    """
    print("开始GPU和训练配置测试...")
    gpu_ok = test_gpu()
    test_device_resolution()
    test_training_components()
    print("\n=== 使用建议 ===")
    # Pick the suggestion set matching the detected hardware.
    if gpu_ok:
        suggestions = (
            "1. 使用GPU运行: python run_pipeline.py --device cuda",
            "2. 或自动选择: python run_pipeline.py --device auto",
            "3. 单独训练: python train.py --device cuda",
            "4. 单独采样: python export_samples.py --device cuda",
        )
    else:
        suggestions = (
            "1. 只能使用CPU运行: python run_pipeline.py --device cpu",
            "2. 检查CUDA和PyTorch安装",
            "3. 确保有NVIDIA GPU和正确的驱动程序",
        )
    for line in suggestions:
        print(line)
    return gpu_ok
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)