update
This commit is contained in:
@@ -62,10 +62,14 @@ def main():
|
|||||||
# load base config, override split/stats/vocab/out_dir
|
# load base config, override split/stats/vocab/out_dir
|
||||||
cfg_path = Path(args.config)
|
cfg_path = Path(args.config)
|
||||||
cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
|
cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
|
||||||
|
repo_dir = base_dir.parent.parent
|
||||||
cfg["split_path"] = str(split_path)
|
cfg["split_path"] = str(split_path)
|
||||||
cfg["stats_path"] = str(results_dir / f"cont_stats_{tag}.json")
|
cfg["stats_path"] = str(results_dir / f"cont_stats_{tag}.json")
|
||||||
cfg["vocab_path"] = str(results_dir / f"disc_vocab_{tag}.json")
|
cfg["vocab_path"] = str(results_dir / f"disc_vocab_{tag}.json")
|
||||||
cfg["out_dir"] = str(results_dir / f"ablation_{tag}")
|
cfg["out_dir"] = str(results_dir / f"ablation_{tag}")
|
||||||
|
# ensure data paths are absolute for Windows
|
||||||
|
cfg["data_glob"] = str(Path(args.data_glob).resolve())
|
||||||
|
cfg["data_path"] = str((repo_dir / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz").resolve())
|
||||||
|
|
||||||
temp_cfg = results_dir / f"config_{tag}.json"
|
temp_cfg = results_dir / f"config_{tag}.json"
|
||||||
temp_cfg.write_text(json.dumps(cfg, indent=2), encoding="utf-8")
|
temp_cfg.write_text(json.dumps(cfg, indent=2), encoding="utf-8")
|
||||||
|
|||||||
Reference in New Issue
Block a user