diff --git a/example/train.py b/example/train.py index e265116..13699c7 100755 --- a/example/train.py +++ b/example/train.py @@ -18,7 +18,7 @@ from hybrid_diffusion import ( q_sample_continuous, q_sample_discrete, ) -from platform_utils import resolve_device, safe_path, ensure_dir +from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path BASE_DIR = Path(__file__).resolve().parent @@ -84,12 +84,17 @@ def resolve_config_paths(config, base_dir: Path): if key in config: # 如果值是字符串,转换为Path对象 if isinstance(config[key], str): - path = Path(config[key]) + path_str = config[key] + # glob pattern cannot be Path.resolve()'d on Windows + if "*" in path_str or "?" in path_str or "[" in path_str: + config[key] = str((base_dir / Path(path_str))) + continue + path = Path(path_str) else: path = config[key] if not path.is_absolute(): - config[key] = str((base_dir / path).resolve()) + config[key] = str(resolve_path(base_dir, path)) else: config[key] = str(path) return config