#!/usr/bin/env python3 """测试所有跨平台修改""" import os import sys import importlib from pathlib import Path def test_imports(): """测试所有模块导入""" print("=== 测试模块导入 ===") modules_to_test = [ "platform_utils", "data_utils", "hybrid_diffusion", "prepare_data", "train", "export_samples", "sample", "train_stub", "evaluate_generated", "plot_loss", "analyze_hai21_03", "run_pipeline", ] all_success = True for module_name in modules_to_test: try: module = importlib.import_module(module_name) print(f"✓ {module_name} 导入成功") # 检查是否有明显的语法错误 if hasattr(module, '__file__'): print(f" 路径: {module.__file__}") except ImportError as e: print(f"✗ {module_name} 导入失败: {e}") all_success = False except SyntaxError as e: print(f"✗ {module_name} 语法错误: {e}") all_success = False except Exception as e: print(f"✗ {module_name} 其他错误: {e}") all_success = False return all_success def test_platform_utils(): """测试平台工具函数""" print("\n=== 测试 platform_utils ===") try: from platform_utils import ( get_platform_info, is_windows, is_linux, is_macos, resolve_device, safe_path, ensure_dir, print_platform_summary ) print("✓ 所有函数导入成功") # 测试平台检测 info = get_platform_info() print(f" 系统: {info['system']}") print(f" Windows检测: {is_windows()}") print(f" Linux检测: {is_linux()}") print(f" macOS检测: {is_macos()}") # 测试路径处理 test_path = "some/path/to/file.txt" safe_result = safe_path(test_path) print(f" 路径处理: '{test_path}' -> '{safe_result}'") # 测试设备检测 print(" 设备检测测试:") for device in ["auto", "cpu", "cuda"]: try: result = resolve_device(device, verbose=False) print(f" '{device}' -> '{result}'") except Exception as e: print(f" '{device}' -> 错误: {e}") return True except Exception as e: print(f"✗ platform_utils 测试失败: {e}") return False def test_path_handling(): """测试路径处理""" print("\n=== 测试路径处理 ===") try: from platform_utils import safe_path, ensure_dir import tempfile # 测试safe_path test_cases = [ "simple/path", "path/with\\mixed\\separators", Path("pathlib/object"), "C:\\Windows\\Path" if os.name == 'nt' else "/linux/path" ] for test in test_cases: result = safe_path(test) print(f" safe_path('{test}') -> '{result}'") # 测试ensure_dir with tempfile.TemporaryDirectory() as tmpdir: test_dir = Path(tmpdir) / "test" / "subdir" ensure_dir(test_dir) if test_dir.exists(): print(f" ensure_dir 成功: {test_dir}") else: print(f" ensure_dir 失败: {test_dir}") return True except Exception as e: print(f"✗ 路径处理测试失败: {e}") return False def test_device_resolution(): """测试设备解析""" print("\n=== 测试设备解析 ===") try: from platform_utils import resolve_device print("设备解析结果:") for device_mode in ["auto", "cpu", "cuda"]: try: device = resolve_device(device_mode, verbose=True) print(f" 模式 '{device_mode}': {device}") except Exception as e: print(f" 模式 '{device_mode}' 错误: {e}") return True except Exception as e: print(f"✗ 设备解析测试失败: {e}") return False def check_file_modifications(): """检查文件修改""" print("\n=== 检查文件修改 ===") files_to_check = [ "platform_utils.py", "train.py", "export_samples.py", "sample.py", "train_stub.py", "run_pipeline.py", "prepare_data.py", ] all_exist = True for filename in files_to_check: filepath = Path(__file__).parent / filename if filepath.exists(): print(f"✓ {filename} 存在") # 检查文件大小 size = filepath.stat().st_size print(f" 大小: {size} 字节") else: print(f"✗ {filename} 不存在") all_exist = False return all_exist def main(): print("跨平台修改测试工具") print("=" * 60) # 打印平台信息 from platform_utils import print_platform_summary print_platform_summary() # 运行测试 tests = [ ("文件修改检查", check_file_modifications), ("模块导入测试", test_imports), ("平台工具测试", test_platform_utils), ("路径处理测试", test_path_handling), ("设备解析测试", test_device_resolution), ] results = [] for test_name, test_func in tests: try: success = test_func() results.append((test_name, success)) except Exception as e: print(f"{test_name} 测试异常: {e}") results.append((test_name, False)) print("\n" + "=" * 60) print("=== 测试结果汇总 ===") all_passed = True for test_name, success in results: status = "✓ 通过" if success else "✗ 失败" print(f"{test_name}: {status}") if not success: all_passed = False print("\n" + "=" * 60) print("=== 使用说明 ===") if all_passed: print("所有测试通过!代码现在应该可以在Windows和Linux上运行。") print("\n运行完整流程:") print(" python run_pipeline.py --device auto") print("\n单独运行:") print(" python prepare_data.py") print(" python train.py --device auto") print(" python export_samples.py --device auto --include-time") print(" python evaluate_generated.py") print(" python plot_loss.py") else: print("部分测试失败,需要进一步检查。") print("\n设备选项:") print(" --device auto : 自动检测最佳设备(推荐)") print(" --device cpu : 强制使用CPU") print(" --device cuda : 强制使用GPU(如果可用)") return all_passed if __name__ == "__main__": success = main() sys.exit(0 if success else 1)