Win and linux can run the code
This commit is contained in:
235
example/test_all_modifications.py
Normal file
235
example/test_all_modifications.py
Normal file
@@ -0,0 +1,235 @@
|
||||
#!/usr/bin/env python3
|
||||
"""测试所有跨平台修改"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
|
||||
def test_imports():
|
||||
"""测试所有模块导入"""
|
||||
print("=== 测试模块导入 ===")
|
||||
|
||||
modules_to_test = [
|
||||
"platform_utils",
|
||||
"data_utils",
|
||||
"hybrid_diffusion",
|
||||
"prepare_data",
|
||||
"train",
|
||||
"export_samples",
|
||||
"sample",
|
||||
"train_stub",
|
||||
"evaluate_generated",
|
||||
"plot_loss",
|
||||
"analyze_hai21_03",
|
||||
"run_pipeline",
|
||||
]
|
||||
|
||||
all_success = True
|
||||
for module_name in modules_to_test:
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
print(f"✓ {module_name} 导入成功")
|
||||
# 检查是否有明显的语法错误
|
||||
if hasattr(module, '__file__'):
|
||||
print(f" 路径: {module.__file__}")
|
||||
except ImportError as e:
|
||||
print(f"✗ {module_name} 导入失败: {e}")
|
||||
all_success = False
|
||||
except SyntaxError as e:
|
||||
print(f"✗ {module_name} 语法错误: {e}")
|
||||
all_success = False
|
||||
except Exception as e:
|
||||
print(f"✗ {module_name} 其他错误: {e}")
|
||||
all_success = False
|
||||
|
||||
return all_success
|
||||
|
||||
def test_platform_utils():
|
||||
"""测试平台工具函数"""
|
||||
print("\n=== 测试 platform_utils ===")
|
||||
|
||||
try:
|
||||
from platform_utils import (
|
||||
get_platform_info,
|
||||
is_windows,
|
||||
is_linux,
|
||||
is_macos,
|
||||
resolve_device,
|
||||
safe_path,
|
||||
ensure_dir,
|
||||
print_platform_summary
|
||||
)
|
||||
|
||||
print("✓ 所有函数导入成功")
|
||||
|
||||
# 测试平台检测
|
||||
info = get_platform_info()
|
||||
print(f" 系统: {info['system']}")
|
||||
print(f" Windows检测: {is_windows()}")
|
||||
print(f" Linux检测: {is_linux()}")
|
||||
print(f" macOS检测: {is_macos()}")
|
||||
|
||||
# 测试路径处理
|
||||
test_path = "some/path/to/file.txt"
|
||||
safe_result = safe_path(test_path)
|
||||
print(f" 路径处理: '{test_path}' -> '{safe_result}'")
|
||||
|
||||
# 测试设备检测
|
||||
print(" 设备检测测试:")
|
||||
for device in ["auto", "cpu", "cuda"]:
|
||||
try:
|
||||
result = resolve_device(device, verbose=False)
|
||||
print(f" '{device}' -> '{result}'")
|
||||
except Exception as e:
|
||||
print(f" '{device}' -> 错误: {e}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"✗ platform_utils 测试失败: {e}")
|
||||
return False
|
||||
|
||||
def test_path_handling():
|
||||
"""测试路径处理"""
|
||||
print("\n=== 测试路径处理 ===")
|
||||
|
||||
try:
|
||||
from platform_utils import safe_path, ensure_dir
|
||||
import tempfile
|
||||
|
||||
# 测试safe_path
|
||||
test_cases = [
|
||||
"simple/path",
|
||||
"path/with\\mixed\\separators",
|
||||
Path("pathlib/object"),
|
||||
"C:\\Windows\\Path" if os.name == 'nt' else "/linux/path"
|
||||
]
|
||||
|
||||
for test in test_cases:
|
||||
result = safe_path(test)
|
||||
print(f" safe_path('{test}') -> '{result}'")
|
||||
|
||||
# 测试ensure_dir
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_dir = Path(tmpdir) / "test" / "subdir"
|
||||
ensure_dir(test_dir)
|
||||
if test_dir.exists():
|
||||
print(f" ensure_dir 成功: {test_dir}")
|
||||
else:
|
||||
print(f" ensure_dir 失败: {test_dir}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"✗ 路径处理测试失败: {e}")
|
||||
return False
|
||||
|
||||
def test_device_resolution():
|
||||
"""测试设备解析"""
|
||||
print("\n=== 测试设备解析 ===")
|
||||
|
||||
try:
|
||||
from platform_utils import resolve_device
|
||||
|
||||
print("设备解析结果:")
|
||||
for device_mode in ["auto", "cpu", "cuda"]:
|
||||
try:
|
||||
device = resolve_device(device_mode, verbose=True)
|
||||
print(f" 模式 '{device_mode}': {device}")
|
||||
except Exception as e:
|
||||
print(f" 模式 '{device_mode}' 错误: {e}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"✗ 设备解析测试失败: {e}")
|
||||
return False
|
||||
|
||||
def check_file_modifications():
|
||||
"""检查文件修改"""
|
||||
print("\n=== 检查文件修改 ===")
|
||||
|
||||
files_to_check = [
|
||||
"platform_utils.py",
|
||||
"train.py",
|
||||
"export_samples.py",
|
||||
"sample.py",
|
||||
"train_stub.py",
|
||||
"run_pipeline.py",
|
||||
"prepare_data.py",
|
||||
]
|
||||
|
||||
all_exist = True
|
||||
for filename in files_to_check:
|
||||
filepath = Path(__file__).parent / filename
|
||||
if filepath.exists():
|
||||
print(f"✓ {filename} 存在")
|
||||
# 检查文件大小
|
||||
size = filepath.stat().st_size
|
||||
print(f" 大小: {size} 字节")
|
||||
else:
|
||||
print(f"✗ {filename} 不存在")
|
||||
all_exist = False
|
||||
|
||||
return all_exist
|
||||
|
||||
def main():
|
||||
print("跨平台修改测试工具")
|
||||
print("=" * 60)
|
||||
|
||||
# 打印平台信息
|
||||
from platform_utils import print_platform_summary
|
||||
print_platform_summary()
|
||||
|
||||
# 运行测试
|
||||
tests = [
|
||||
("文件修改检查", check_file_modifications),
|
||||
("模块导入测试", test_imports),
|
||||
("平台工具测试", test_platform_utils),
|
||||
("路径处理测试", test_path_handling),
|
||||
("设备解析测试", test_device_resolution),
|
||||
]
|
||||
|
||||
results = []
|
||||
for test_name, test_func in tests:
|
||||
try:
|
||||
success = test_func()
|
||||
results.append((test_name, success))
|
||||
except Exception as e:
|
||||
print(f"{test_name} 测试异常: {e}")
|
||||
results.append((test_name, False))
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("=== 测试结果汇总 ===")
|
||||
|
||||
all_passed = True
|
||||
for test_name, success in results:
|
||||
status = "✓ 通过" if success else "✗ 失败"
|
||||
print(f"{test_name}: {status}")
|
||||
if not success:
|
||||
all_passed = False
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("=== 使用说明 ===")
|
||||
|
||||
if all_passed:
|
||||
print("所有测试通过!代码现在应该可以在Windows和Linux上运行。")
|
||||
print("\n运行完整流程:")
|
||||
print(" python run_pipeline.py --device auto")
|
||||
print("\n单独运行:")
|
||||
print(" python prepare_data.py")
|
||||
print(" python train.py --device auto")
|
||||
print(" python export_samples.py --device auto --include-time")
|
||||
print(" python evaluate_generated.py")
|
||||
print(" python plot_loss.py")
|
||||
else:
|
||||
print("部分测试失败,需要进一步检查。")
|
||||
|
||||
print("\n设备选项:")
|
||||
print(" --device auto : 自动检测最佳设备(推荐)")
|
||||
print(" --device cpu : 强制使用CPU")
|
||||
print(" --device cuda : 强制使用GPU(如果可用)")
|
||||
|
||||
return all_passed
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
||||
Reference in New Issue
Block a user