From f2b447ac38392b54bc981f2092e57dc98d1190bd Mon Sep 17 00:00:00 2001 From: Mingzhe Yang Date: Wed, 4 Feb 2026 02:59:49 +0800 Subject: [PATCH] Fix config path resolution for one-click runners --- example/run_all.py | 38 ++++++++++++++++++++++++-------------- example/run_all_full.py | 34 ++++++++++++++++++++++------------ 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/example/run_all.py b/example/run_all.py index 5a7682a..9a9f602 100644 --- a/example/run_all.py +++ b/example/run_all.py @@ -7,7 +7,7 @@ import subprocess import sys from pathlib import Path -from platform_utils import safe_path, is_windows, resolve_path +from platform_utils import safe_path, is_windows def run(cmd): @@ -35,24 +35,34 @@ def parse_args(): return parser.parse_args() +def resolve_config_path(base_dir: Path, cfg_arg: str) -> Path: + p = Path(cfg_arg) + if p.is_absolute(): + if p.exists(): + return p.resolve() + raise SystemExit(f"config not found: {p}") + + repo_dir = base_dir.parent + candidates = [p, base_dir / p, repo_dir / p] + if p.parts and p.parts[0] == "example": + trimmed = Path(*p.parts[1:]) if len(p.parts) > 1 else Path() + if str(trimmed): + candidates.extend([base_dir / trimmed, repo_dir / trimmed]) + + for c in candidates: + if c.exists(): + return c.resolve() + + tried = "\n".join(str(c) for c in candidates) + raise SystemExit(f"config not found: {cfg_arg}\ntried:\n{tried}") + + def main(): args = parse_args() base_dir = Path(__file__).resolve().parent - config_path = Path(args.config) + config_path = resolve_config_path(base_dir, args.config) with open(config_path, "r", encoding="utf-8") as f: cfg = json.load(f) - - # Resolve config path without duplicating base_dir on Windows when user passes example/config.json - if config_path.is_absolute(): - config_path = resolve_path(config_path.parent, config_path) - else: - candidate = base_dir / config_path - if candidate.exists(): - config_path = resolve_path(candidate.parent, candidate) - elif config_path.exists(): - config_path = resolve_path(config_path.parent, config_path) - else: - config_path = resolve_path(base_dir, config_path) timesteps = cfg.get("timesteps", 200) seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64)) batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2)) diff --git a/example/run_all_full.py b/example/run_all_full.py index 9ebd722..62a51b7 100644 --- a/example/run_all_full.py +++ b/example/run_all_full.py @@ -7,7 +7,7 @@ import subprocess import sys from pathlib import Path -from platform_utils import safe_path, is_windows, resolve_path +from platform_utils import safe_path, is_windows def run(cmd): @@ -32,22 +32,32 @@ def parse_args(): return parser.parse_args() -def resolve_config(base_dir: Path, cfg_arg: str) -> Path: - config_path = Path(cfg_arg) - if config_path.is_absolute(): - return Path(resolve_path(config_path.parent, config_path)) - candidate = base_dir / config_path - if candidate.exists(): - return Path(resolve_path(candidate.parent, candidate)) - if config_path.exists(): - return Path(resolve_path(config_path.parent, config_path)) - return Path(resolve_path(base_dir, config_path)) +def resolve_config_path(base_dir: Path, cfg_arg: str) -> Path: + p = Path(cfg_arg) + if p.is_absolute(): + if p.exists(): + return p.resolve() + raise SystemExit(f"config not found: {p}") + + repo_dir = base_dir.parent + candidates = [p, base_dir / p, repo_dir / p] + if p.parts and p.parts[0] == "example": + trimmed = Path(*p.parts[1:]) if len(p.parts) > 1 else Path() + if str(trimmed): + candidates.extend([base_dir / trimmed, repo_dir / trimmed]) + + for c in candidates: + if c.exists(): + return c.resolve() + + tried = "\n".join(str(c) for c in candidates) + raise SystemExit(f"config not found: {cfg_arg}\ntried:\n{tried}") def main(): args = parse_args() base_dir = Path(__file__).resolve().parent - config_path = resolve_config(base_dir, args.config) + config_path = resolve_config_path(base_dir, args.config) with open(config_path, "r", encoding="utf-8") as f: cfg = json.load(f)