260 lines
7.4 KiB
Python
260 lines
7.4 KiB
Python
#!/usr/bin/env python3
|
||
"""测试所有跨平台修改"""
|
||
|
||
import os
|
||
import sys
|
||
import importlib
|
||
from pathlib import Path
|
||
|
||
def test_imports():
    """Import every project module and report which ones load cleanly.

    Modules that need an optional third-party package (torch or matplotlib)
    are skipped, with a message, when that package cannot be imported.

    Returns:
        True if every attempted import succeeded, otherwise False.
    """
    print("=== 测试模块导入 ===")

    def _importable(name):
        # Any exception (not only ImportError) counts as "unavailable", so a
        # broken install is treated the same as a missing one.
        try:
            importlib.import_module(name)
        except Exception:
            return False
        return True

    torch_ok = _importable("torch")
    matplotlib_ok = _importable("matplotlib")

    # Project modules that cannot be imported without torch.
    needs_torch = {
        "hybrid_diffusion",
        "train",
        "export_samples",
        "sample",
        "train_stub",
    }

    candidates = (
        "platform_utils",
        "data_utils",
        "hybrid_diffusion",
        "prepare_data",
        "train",
        "export_samples",
        "sample",
        "train_stub",
        "evaluate_generated",
        "plot_loss",
        "analyze_hai21_03",
        "run_pipeline",
    )

    ok = True
    for module_name in candidates:
        if module_name in needs_torch and not torch_ok:
            print(f"↷ {module_name} 跳过(未安装 torch)")
            continue
        if module_name == "plot_loss" and not matplotlib_ok:
            print(f"↷ {module_name} 跳过(未安装 matplotlib)")
            continue

        try:
            module = importlib.import_module(module_name)
            print(f"✓ {module_name} 导入成功")
            # Not every module object has __file__ (e.g. built-in modules).
            if hasattr(module, '__file__'):
                print(f" 路径: {module.__file__}")
        except Exception as err:
            ok = False
            # Distinguish missing modules, syntax errors and everything else.
            if isinstance(err, ImportError):
                print(f"✗ {module_name} 导入失败: {err}")
            elif isinstance(err, SyntaxError):
                print(f"✗ {module_name} 语法错误: {err}")
            else:
                print(f"✗ {module_name} 其他错误: {err}")

    return ok
|
||
|
||
def test_platform_utils():
    """Smoke-test the helper functions exported by ``platform_utils``.

    Importing every helper name up front doubles as a check that they all
    exist. The function then prints the detected system, the three OS
    detection flags, one ``safe_path`` conversion, and the ``resolve_device``
    result for the "auto", "cpu" and "cuda" modes. An error from
    ``resolve_device`` for a single mode is printed and tolerated.

    Returns:
        True if nothing outside the per-mode device loop raised, else False.
    """
    print("\n=== 测试 platform_utils ===")

    try:
        # Some of these names are not used below; importing them is the check.
        from platform_utils import (
            get_platform_info,
            is_windows,
            is_linux,
            is_macos,
            resolve_device,
            safe_path,
            ensure_dir,
            print_platform_summary
        )
        print("✓ 所有函数导入成功")

        # --- OS detection ---
        system_name = get_platform_info()['system']
        print(f" 系统: {system_name}")
        print(f" Windows检测: {is_windows()}")
        print(f" Linux检测: {is_linux()}")
        print(f" macOS检测: {is_macos()}")

        # --- path conversion ---
        sample_path = "some/path/to/file.txt"
        converted = safe_path(sample_path)
        print(f" 路径处理: '{sample_path}' -> '{converted}'")

        # --- device resolution: a failing mode is reported, not fatal ---
        print(" 设备检测测试:")
        for mode in ("auto", "cpu", "cuda"):
            try:
                resolved = resolve_device(mode, verbose=False)
                print(f" '{mode}' -> '{resolved}'")
            except Exception as exc:
                print(f" '{mode}' -> 错误: {exc}")
    except Exception as exc:
        print(f"✗ platform_utils 测试失败: {exc}")
        return False

    return True
|
||
|
||
def test_path_handling():
    """Exercise ``safe_path`` on sample inputs and ``ensure_dir`` on a nested path.

    ``safe_path`` results are only printed; they are not checked against
    expected values. ``ensure_dir`` must create a two-level directory inside
    a fresh temporary directory.

    Returns:
        True if every step ran and ``ensure_dir`` actually produced the nested
        directory; False if any step raised or the directory is missing.
    """
    print("\n=== 测试路径处理 ===")

    try:
        from platform_utils import safe_path, ensure_dir
        import tempfile

        # safe_path inputs: plain, mixed separators, a Path object, and an
        # absolute path written in the host OS's native form.
        test_cases = [
            "simple/path",
            "path/with\\mixed\\separators",
            Path("pathlib/object"),
            "C:\\Windows\\Path" if os.name == 'nt' else "/linux/path"
        ]

        for test in test_cases:
            result = safe_path(test)
            print(f" safe_path('{test}') -> '{result}'")

        # ensure_dir: "test" does not exist yet, so intermediate directories
        # must be created as well.
        with tempfile.TemporaryDirectory() as tmpdir:
            test_dir = Path(tmpdir) / "test" / "subdir"
            ensure_dir(test_dir)
            created = test_dir.is_dir()
            if created:
                print(f" ensure_dir 成功: {test_dir}")
            else:
                print(f" ensure_dir 失败: {test_dir}")

        # A failed ensure_dir must fail the test, not just print a message.
        return created

    except Exception as e:
        print(f"✗ 路径处理测试失败: {e}")
        return False
|
||
|
||
def test_device_resolution():
    """Print what ``resolve_device`` returns for each device mode.

    Calls the helper with ``verbose=True`` for "auto", "cpu" and "cuda".
    An error for an individual mode is printed and does not fail the test
    (presumably because some modes, e.g. "cuda", may be unavailable on the
    host; confirm against platform_utils).

    Returns:
        False if ``resolve_device`` cannot be imported or an error escapes
        the per-mode handling; otherwise True.
    """
    print("\n=== 测试设备解析 ===")

    try:
        from platform_utils import resolve_device

        print("设备解析结果:")
        for mode in ("auto", "cpu", "cuda"):
            try:
                resolved = resolve_device(mode, verbose=True)
                print(f" 模式 '{mode}': {resolved}")
            except Exception as exc:
                print(f" 模式 '{mode}' 错误: {exc}")
    except Exception as exc:
        print(f"✗ 设备解析测试失败: {exc}")
        return False

    return True
|
||
|
||
def check_file_modifications():
    """Confirm that the expected project source files sit next to this script.

    Prints each file's size in bytes when it is present.

    Returns:
        True if every listed file exists, otherwise False.
    """
    print("\n=== 检查文件修改 ===")

    # The files are looked up relative to this script, not the current directory.
    here = Path(__file__).parent
    expected = (
        "platform_utils.py",
        "train.py",
        "export_samples.py",
        "sample.py",
        "train_stub.py",
        "run_pipeline.py",
        "prepare_data.py",
    )

    missing = 0
    for filename in expected:
        candidate = here / filename
        if not candidate.exists():
            print(f"✗ {filename} 不存在")
            missing += 1
            continue
        print(f"✓ {filename} 存在")
        byte_count = candidate.stat().st_size
        print(f" 大小: {byte_count} 字节")

    return missing == 0
|
||
|
||
def main():
    """Run all cross-platform checks and print a pass/fail summary.

    Steps:
    1. Print a platform summary.
    2. Run each check in order.
    3. List the per-check results.
    4. Print usage hints (all passed) or a warning (some failed).
    5. Print the ``--device`` options.

    Returns:
        True if every check passed and the platform summary printed without
        error, otherwise False.
    """
    print("跨平台修改测试工具")
    print("=" * 60)

    results = []

    # Print platform info. Previously a missing or broken platform_utils
    # crashed the whole tool here, before any check ran. Record it as a
    # failure instead so the remaining checks still report what is wrong.
    try:
        from platform_utils import print_platform_summary
        print_platform_summary()
    except Exception as e:
        print(f"✗ 平台信息打印失败: {e}")
        results.append(("平台信息打印", False))

    # Run the checks
    tests = [
        ("文件修改检查", check_file_modifications),
        ("模块导入测试", test_imports),
        ("平台工具测试", test_platform_utils),
        ("路径处理测试", test_path_handling),
        ("设备解析测试", test_device_resolution),
    ]

    for test_name, test_func in tests:
        try:
            success = test_func()
        except Exception as e:
            # A check that raises instead of returning counts as a failure.
            print(f"{test_name} 测试异常: {e}")
            success = False
        results.append((test_name, success))

    print("\n" + "=" * 60)
    print("=== 测试结果汇总 ===")

    for test_name, success in results:
        status = "✓ 通过" if success else "✗ 失败"
        print(f"{test_name}: {status}")
    all_passed = all(success for _, success in results)

    print("\n" + "=" * 60)
    print("=== 使用说明 ===")

    if all_passed:
        print("所有测试通过!代码现在应该可以在Windows和Linux上运行。")
        print("\n运行完整流程:")
        print(" python run_pipeline.py --device auto")
        print("\n单独运行:")
        print(" python prepare_data.py")
        print(" python train.py --device auto")
        print(" python export_samples.py --device auto --include-time")
        print(" python evaluate_generated.py")
        print(" python plot_loss.py")
    else:
        print("部分测试失败,需要进一步检查。")

    print("\n设备选项:")
    print(" --device auto : 自动检测最佳设备(推荐)")
    print(" --device cpu : 强制使用CPU")
    print(" --device cuda : 强制使用GPU(如果可用)")

    return all_passed
|
||
|
||
if __name__ == "__main__":
    # Process exit status: 0 when main() reports that every check passed, 1 otherwise.
    sys.exit(0 if main() else 1)
|