diff --git a/example/run_all.py b/example/run_all.py index 7cf530a..e0d1ed6 100644 --- a/example/run_all.py +++ b/example/run_all.py @@ -39,10 +39,16 @@ def main(): cfg = json.load(f) # Resolve config path without duplicating base_dir on Windows when user passes example/config.json - if config_path.exists(): + if config_path.is_absolute(): config_path = resolve_path(config_path.parent, config_path) else: - config_path = resolve_path(base_dir, config_path) + candidate = base_dir / config_path + if candidate.exists(): + config_path = resolve_path(candidate.parent, candidate) + elif config_path.exists(): + config_path = resolve_path(config_path.parent, config_path) + else: + config_path = resolve_path(base_dir, config_path) timesteps = cfg.get("timesteps", 200) seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64)) batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2))