Win and linux can run the code

This commit is contained in:
MZ YANG
2026-01-22 17:39:31 +08:00
parent c3f750cd9d
commit f37a8ce179
22 changed files with 32572 additions and 87 deletions

View File

@@ -0,0 +1,235 @@
#!/usr/bin/env python3
"""测试所有跨平台修改"""
import os
import sys
import importlib
from pathlib import Path
def test_imports():
"""测试所有模块导入"""
print("=== 测试模块导入 ===")
modules_to_test = [
"platform_utils",
"data_utils",
"hybrid_diffusion",
"prepare_data",
"train",
"export_samples",
"sample",
"train_stub",
"evaluate_generated",
"plot_loss",
"analyze_hai21_03",
"run_pipeline",
]
all_success = True
for module_name in modules_to_test:
try:
module = importlib.import_module(module_name)
print(f"{module_name} 导入成功")
# 检查是否有明显的语法错误
if hasattr(module, '__file__'):
print(f" 路径: {module.__file__}")
except ImportError as e:
print(f"{module_name} 导入失败: {e}")
all_success = False
except SyntaxError as e:
print(f"{module_name} 语法错误: {e}")
all_success = False
except Exception as e:
print(f"{module_name} 其他错误: {e}")
all_success = False
return all_success
def test_platform_utils():
"""测试平台工具函数"""
print("\n=== 测试 platform_utils ===")
try:
from platform_utils import (
get_platform_info,
is_windows,
is_linux,
is_macos,
resolve_device,
safe_path,
ensure_dir,
print_platform_summary
)
print("✓ 所有函数导入成功")
# 测试平台检测
info = get_platform_info()
print(f" 系统: {info['system']}")
print(f" Windows检测: {is_windows()}")
print(f" Linux检测: {is_linux()}")
print(f" macOS检测: {is_macos()}")
# 测试路径处理
test_path = "some/path/to/file.txt"
safe_result = safe_path(test_path)
print(f" 路径处理: '{test_path}' -> '{safe_result}'")
# 测试设备检测
print(" 设备检测测试:")
for device in ["auto", "cpu", "cuda"]:
try:
result = resolve_device(device, verbose=False)
print(f" '{device}' -> '{result}'")
except Exception as e:
print(f" '{device}' -> 错误: {e}")
return True
except Exception as e:
print(f"✗ platform_utils 测试失败: {e}")
return False
def test_path_handling():
"""测试路径处理"""
print("\n=== 测试路径处理 ===")
try:
from platform_utils import safe_path, ensure_dir
import tempfile
# 测试safe_path
test_cases = [
"simple/path",
"path/with\\mixed\\separators",
Path("pathlib/object"),
"C:\\Windows\\Path" if os.name == 'nt' else "/linux/path"
]
for test in test_cases:
result = safe_path(test)
print(f" safe_path('{test}') -> '{result}'")
# 测试ensure_dir
with tempfile.TemporaryDirectory() as tmpdir:
test_dir = Path(tmpdir) / "test" / "subdir"
ensure_dir(test_dir)
if test_dir.exists():
print(f" ensure_dir 成功: {test_dir}")
else:
print(f" ensure_dir 失败: {test_dir}")
return True
except Exception as e:
print(f"✗ 路径处理测试失败: {e}")
return False
def test_device_resolution():
"""测试设备解析"""
print("\n=== 测试设备解析 ===")
try:
from platform_utils import resolve_device
print("设备解析结果:")
for device_mode in ["auto", "cpu", "cuda"]:
try:
device = resolve_device(device_mode, verbose=True)
print(f" 模式 '{device_mode}': {device}")
except Exception as e:
print(f" 模式 '{device_mode}' 错误: {e}")
return True
except Exception as e:
print(f"✗ 设备解析测试失败: {e}")
return False
def check_file_modifications():
"""检查文件修改"""
print("\n=== 检查文件修改 ===")
files_to_check = [
"platform_utils.py",
"train.py",
"export_samples.py",
"sample.py",
"train_stub.py",
"run_pipeline.py",
"prepare_data.py",
]
all_exist = True
for filename in files_to_check:
filepath = Path(__file__).parent / filename
if filepath.exists():
print(f"{filename} 存在")
# 检查文件大小
size = filepath.stat().st_size
print(f" 大小: {size} 字节")
else:
print(f"{filename} 不存在")
all_exist = False
return all_exist
def main():
print("跨平台修改测试工具")
print("=" * 60)
# 打印平台信息
from platform_utils import print_platform_summary
print_platform_summary()
# 运行测试
tests = [
("文件修改检查", check_file_modifications),
("模块导入测试", test_imports),
("平台工具测试", test_platform_utils),
("路径处理测试", test_path_handling),
("设备解析测试", test_device_resolution),
]
results = []
for test_name, test_func in tests:
try:
success = test_func()
results.append((test_name, success))
except Exception as e:
print(f"{test_name} 测试异常: {e}")
results.append((test_name, False))
print("\n" + "=" * 60)
print("=== 测试结果汇总 ===")
all_passed = True
for test_name, success in results:
status = "✓ 通过" if success else "✗ 失败"
print(f"{test_name}: {status}")
if not success:
all_passed = False
print("\n" + "=" * 60)
print("=== 使用说明 ===")
if all_passed:
print("所有测试通过代码现在应该可以在Windows和Linux上运行。")
print("\n运行完整流程:")
print(" python run_pipeline.py --device auto")
print("\n单独运行:")
print(" python prepare_data.py")
print(" python train.py --device auto")
print(" python export_samples.py --device auto --include-time")
print(" python evaluate_generated.py")
print(" python plot_loss.py")
else:
print("部分测试失败,需要进一步检查。")
print("\n设备选项:")
print(" --device auto : 自动检测最佳设备(推荐)")
print(" --device cpu : 强制使用CPU")
print(" --device cuda : 强制使用GPU如果可用")
return all_passed
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)