update
This commit is contained in:
@@ -10,6 +10,18 @@ def test_imports():
|
||||
"""测试所有模块导入"""
|
||||
print("=== 测试模块导入 ===")
|
||||
|
||||
try:
|
||||
import torch
|
||||
has_torch = True
|
||||
except Exception:
|
||||
has_torch = False
|
||||
|
||||
try:
|
||||
import matplotlib
|
||||
has_matplotlib = True
|
||||
except Exception:
|
||||
has_matplotlib = False
|
||||
|
||||
modules_to_test = [
|
||||
"platform_utils",
|
||||
"data_utils",
|
||||
@@ -27,6 +39,18 @@ def test_imports():
|
||||
|
||||
all_success = True
|
||||
for module_name in modules_to_test:
|
||||
if not has_torch and module_name in {
|
||||
"hybrid_diffusion",
|
||||
"train",
|
||||
"export_samples",
|
||||
"sample",
|
||||
"train_stub",
|
||||
}:
|
||||
print(f"↷ {module_name} 跳过(未安装 torch)")
|
||||
continue
|
||||
if not has_matplotlib and module_name in {"plot_loss"}:
|
||||
print(f"↷ {module_name} 跳过(未安装 matplotlib)")
|
||||
continue
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
print(f"✓ {module_name} 导入成功")
|
||||
@@ -232,4 +256,4 @@ def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
Reference in New Issue
Block a user