This commit is contained in:
2026-01-23 13:07:16 +08:00
parent 5547e89287
commit 45752d6c97

View File

@@ -62,10 +62,14 @@ def main():
# load base config, override split/stats/vocab/out_dir # load base config, override split/stats/vocab/out_dir
cfg_path = Path(args.config) cfg_path = Path(args.config)
cfg = json.loads(cfg_path.read_text(encoding="utf-8")) cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
repo_dir = base_dir.parent.parent
cfg["split_path"] = str(split_path) cfg["split_path"] = str(split_path)
cfg["stats_path"] = str(results_dir / f"cont_stats_{tag}.json") cfg["stats_path"] = str(results_dir / f"cont_stats_{tag}.json")
cfg["vocab_path"] = str(results_dir / f"disc_vocab_{tag}.json") cfg["vocab_path"] = str(results_dir / f"disc_vocab_{tag}.json")
cfg["out_dir"] = str(results_dir / f"ablation_{tag}") cfg["out_dir"] = str(results_dir / f"ablation_{tag}")
# ensure data paths are absolute for Windows
cfg["data_glob"] = str(Path(args.data_glob).resolve())
cfg["data_path"] = str((repo_dir / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz").resolve())
temp_cfg = results_dir / f"config_{tag}.json" temp_cfg = results_dir / f"config_{tag}.json"
temp_cfg.write_text(json.dumps(cfg, indent=2), encoding="utf-8") temp_cfg.write_text(json.dumps(cfg, indent=2), encoding="utf-8")