Fix ablation config path resolution

This commit is contained in:
MZ YANG
2026-03-26 22:58:23 +08:00
parent 957b010ea1
commit 633e9fd18c

View File

@@ -93,6 +93,32 @@ def write_json(path: Path, obj: Dict) -> None:
json.dump(obj, f, indent=2) json.dump(obj, f, indent=2)
def absolutize_config_paths(cfg: Dict, base_config_path: Path) -> Dict:
"""Freeze path-like config values so generated ablation configs remain runnable."""
cfg = dict(cfg)
path_keys = ["data_path", "split_path", "stats_path", "vocab_path", "out_dir"]
glob_keys = ["data_glob"]
for key in path_keys:
value = cfg.get(key)
if not value:
continue
path = Path(value)
if not path.is_absolute():
cfg[key] = str((base_config_path.parent / path).resolve())
for key in glob_keys:
value = cfg.get(key)
if not value:
continue
path = Path(value)
if not path.is_absolute():
resolved = base_config_path.parent / path
cfg[key] = str(resolved)
return cfg
def selected_variants(arg: str) -> List[str]: def selected_variants(arg: str) -> List[str]:
if not arg: if not arg:
return list(DEFAULT_ABLATIONS.keys()) return list(DEFAULT_ABLATIONS.keys())
@@ -153,7 +179,7 @@ def main():
generated_configs: Dict[str, Path] = {} generated_configs: Dict[str, Path] = {}
for variant in variants: for variant in variants:
cfg = dict(base_cfg) cfg = absolutize_config_paths(base_cfg, config_path)
cfg.update(DEFAULT_ABLATIONS[variant]) cfg.update(DEFAULT_ABLATIONS[variant])
cfg_path = out_root / "configs" / f"{variant}.json" cfg_path = out_root / "configs" / f"{variant}.json"
write_json(cfg_path, cfg) write_json(cfg_path, cfg)