Files
mask-ddpm/example/test_all_modifications.py
2026-01-22 17:39:31 +08:00

235 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""测试所有跨平台修改"""
import os
import sys
import importlib
from pathlib import Path
def test_imports():
    """Import every project module by name and report the outcome.

    Each name is imported with ``importlib.import_module``. A success line
    (plus the module's file path, when the module has one) or a failure
    line is printed for every entry.

    Returns:
        bool: True when every module imported without raising, False if
        at least one import raised.
    """
    print("=== 测试模块导入 ===")

    module_names = (
        "platform_utils",
        "data_utils",
        "hybrid_diffusion",
        "prepare_data",
        "train",
        "export_samples",
        "sample",
        "train_stub",
        "evaluate_generated",
        "plot_loss",
        "analyze_hai21_03",
        "run_pipeline",
    )

    # Names of the modules whose import raised; empty means full success.
    failed = []
    for name in module_names:
        try:
            loaded = importlib.import_module(name)
            print(f"{name} 导入成功")
            # Show where the module was loaded from, when it records that.
            if hasattr(loaded, '__file__'):
                print(f" 路径: {loaded.__file__}")
        except ImportError as exc:
            print(f"{name} 导入失败: {exc}")
            failed.append(name)
        except SyntaxError as exc:
            print(f"{name} 语法错误: {exc}")
            failed.append(name)
        except Exception as exc:
            # Importing runs module-level code, which may fail in any way.
            print(f"{name} 其他错误: {exc}")
            failed.append(name)

    return not failed
def test_platform_utils():
    """Exercise the helpers exported by ``platform_utils``.

    Imports the expected helper functions, then prints the detected
    platform, the result of ``safe_path`` on a sample path, and the result
    of ``resolve_device`` for the ``auto``, ``cpu`` and ``cuda`` modes.
    A failure to resolve an individual device is printed but does not fail
    the check.

    Returns:
        bool: True when the helpers imported and the platform/path calls
        completed, False if any of them raised.
    """
    print("\n=== 测试 platform_utils ===")
    try:
        # Importing all names at once also verifies that each one exists,
        # including the ones not called below.
        from platform_utils import (
            get_platform_info,
            is_windows,
            is_linux,
            is_macos,
            resolve_device,
            safe_path,
            ensure_dir,
            print_platform_summary
        )
        print("✓ 所有函数导入成功")

        # Platform detection.
        platform_info = get_platform_info()
        system_name = platform_info['system']
        print(f" 系统: {system_name}")
        print(f" Windows检测: {is_windows()}")
        print(f" Linux检测: {is_linux()}")
        print(f" macOS检测: {is_macos()}")

        # Path normalisation.
        sample_path = "some/path/to/file.txt"
        normalised = safe_path(sample_path)
        print(f" 路径处理: '{sample_path}' -> '{normalised}'")

        # Device resolution; each mode is tried independently so that one
        # unavailable device does not hide the results for the others.
        print(" 设备检测测试:")
        for requested in ("auto", "cpu", "cuda"):
            try:
                resolved = resolve_device(requested, verbose=False)
                print(f" '{requested}' -> '{resolved}'")
            except Exception as exc:
                print(f" '{requested}' -> 错误: {exc}")
        return True
    except Exception as exc:
        print(f"✗ platform_utils 测试失败: {exc}")
        return False
def test_path_handling():
    """Check ``safe_path`` and ``ensure_dir`` from ``platform_utils``.

    ``safe_path`` is called on a handful of representative inputs (forward
    slashes, mixed separators, a ``Path`` object and an OS-specific
    absolute path) and each result is printed. ``ensure_dir`` is then asked
    to create a nested directory inside a temporary directory.

    Returns:
        bool: True when every call completed and ``ensure_dir`` actually
        created the requested directory; False if anything raised or the
        directory does not exist afterwards.
    """
    print("\n=== 测试路径处理 ===")
    try:
        from platform_utils import safe_path, ensure_dir
        import tempfile

        # safe_path: print the result for each kind of input.
        test_cases = [
            "simple/path",
            "path/with\\mixed\\separators",
            Path("pathlib/object"),
            "C:\\Windows\\Path" if os.name == 'nt' else "/linux/path",
        ]
        for test in test_cases:
            result = safe_path(test)
            print(f" safe_path('{test}') -> '{result}'")

        # ensure_dir: the nested target does not exist beforehand, so its
        # presence afterwards shows that ensure_dir created it.
        with tempfile.TemporaryDirectory() as tmpdir:
            test_dir = Path(tmpdir) / "test" / "subdir"
            ensure_dir(test_dir)
            created = test_dir.exists()
            if created:
                print(f" ensure_dir 成功: {test_dir}")
            else:
                print(f" ensure_dir 失败: {test_dir}")

        # Report the real outcome: a directory that was not created is a
        # failed check, not a pass.
        return created
    except Exception as e:
        print(f"✗ 路径处理测试失败: {e}")
        return False
def test_device_resolution():
    """Print what ``resolve_device`` returns for each supported mode.

    ``resolve_device`` is called with ``verbose=True`` for the ``auto``,
    ``cpu`` and ``cuda`` modes. An error for an individual mode is printed
    and does not fail the check.

    Returns:
        bool: True when ``resolve_device`` could be imported and the loop
        completed, False if the import (or printing) raised.
    """
    print("\n=== 测试设备解析 ===")
    try:
        from platform_utils import resolve_device

        print("设备解析结果:")
        for mode in ("auto", "cpu", "cuda"):
            # Each mode is isolated so one failure does not skip the rest.
            try:
                resolved = resolve_device(mode, verbose=True)
                print(f" 模式 '{mode}': {resolved}")
            except Exception as exc:
                print(f" 模式 '{mode}' 错误: {exc}")
        return True
    except Exception as exc:
        print(f"✗ 设备解析测试失败: {exc}")
        return False
def check_file_modifications():
    """Verify that the expected project files sit next to this script.

    For each expected file name, prints whether it exists in the directory
    containing this file and, if it does, its size in bytes.

    Returns:
        bool: True when every expected file exists, False otherwise.
    """
    print("\n=== 检查文件修改 ===")

    expected_files = (
        "platform_utils.py",
        "train.py",
        "export_samples.py",
        "sample.py",
        "train_stub.py",
        "run_pipeline.py",
        "prepare_data.py",
    )

    # Files are looked up relative to this script, not the working directory.
    script_dir = Path(__file__).parent
    missing_count = 0
    for filename in expected_files:
        candidate = script_dir / filename
        if not candidate.exists():
            print(f"{filename} 不存在")
            missing_count += 1
            continue
        print(f"{filename} 存在")
        size = candidate.stat().st_size
        print(f" 大小: {size} 字节")

    return missing_count == 0
def main():
    """Run every cross-platform check and print a summary with usage hints.

    Prints the platform summary, runs each check in turn (a check that
    raises is recorded as failed), prints a pass/fail line per check, and
    finally prints usage instructions and the available device options.

    Returns:
        bool: True when every check reported success, False otherwise.
    """
    print("跨平台修改测试工具")
    print("=" * 60)

    # The platform summary is informational only. Guard it like every other
    # platform_utils import in this file so that a missing or broken module
    # is reported by the checks below instead of aborting with a traceback
    # before any result has been printed.
    try:
        from platform_utils import print_platform_summary
        print_platform_summary()
    except Exception as e:
        print(f"✗ 平台信息打印失败: {e}")

    # (display name, check function) pairs, run in this order.
    tests = [
        ("文件修改检查", check_file_modifications),
        ("模块导入测试", test_imports),
        ("平台工具测试", test_platform_utils),
        ("路径处理测试", test_path_handling),
        ("设备解析测试", test_device_resolution),
    ]

    results = []
    for test_name, test_func in tests:
        try:
            success = test_func()
            results.append((test_name, success))
        except Exception as e:
            # An unexpected exception counts as a failure of that check.
            print(f"{test_name} 测试异常: {e}")
            results.append((test_name, False))

    print("\n" + "=" * 60)
    print("=== 测试结果汇总 ===")
    for test_name, success in results:
        status = "✓ 通过" if success else "✗ 失败"
        print(f"{test_name}: {status}")
    all_passed = all(success for _, success in results)

    print("\n" + "=" * 60)
    print("=== 使用说明 ===")
    if all_passed:
        print("所有测试通过代码现在应该可以在Windows和Linux上运行。")
        print("\n运行完整流程:")
        print(" python run_pipeline.py --device auto")
        print("\n单独运行:")
        print(" python prepare_data.py")
        print(" python train.py --device auto")
        print(" python export_samples.py --device auto --include-time")
        print(" python evaluate_generated.py")
        print(" python plot_loss.py")
    else:
        print("部分测试失败,需要进一步检查。")

    # Device options are general usage notes, shown regardless of outcome.
    print("\n设备选项:")
    print(" --device auto : 自动检测最佳设备(推荐)")
    print(" --device cpu : 强制使用CPU")
    print(" --device cuda : 强制使用GPU如果可用")

    return all_passed
if __name__ == "__main__":
    # Turn the overall result into the process exit status:
    # 0 when every check passed, 1 otherwise.
    exit_status = 0 if main() else 1
    sys.exit(exit_status)