diff --git a/example/run_all.py b/example/run_all.py index 68991aa..7cf530a 100644 --- a/example/run_all.py +++ b/example/run_all.py @@ -38,7 +38,11 @@ def main(): with open(config_path, "r", encoding="utf-8") as f: cfg = json.load(f) - config_path = resolve_path(base_dir, config_path) + # Resolve config path without duplicating base_dir on Windows when user passes example/config.json + if config_path.exists(): + config_path = resolve_path(config_path.parent, config_path) + else: + config_path = resolve_path(base_dir, config_path) timesteps = cfg.get("timesteps", 200) seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64)) batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2))