diff --git a/example/run_ablation.py b/example/run_ablation.py index e2a1792..f2b3e75 100644 --- a/example/run_ablation.py +++ b/example/run_ablation.py @@ -62,10 +62,14 @@ def main(): # load base config, override split/stats/vocab/out_dir cfg_path = Path(args.config) cfg = json.loads(cfg_path.read_text(encoding="utf-8")) + repo_dir = base_dir.parent.parent cfg["split_path"] = str(split_path) cfg["stats_path"] = str(results_dir / f"cont_stats_{tag}.json") cfg["vocab_path"] = str(results_dir / f"disc_vocab_{tag}.json") cfg["out_dir"] = str(results_dir / f"ablation_{tag}") + # ensure data paths are absolute for Windows + cfg["data_glob"] = str(Path(args.data_glob).resolve()) + cfg["data_path"] = str((repo_dir / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz").resolve()) temp_cfg = results_dir / f"config_{tag}.json" temp_cfg.write_text(json.dumps(cfg, indent=2), encoding="utf-8")