Fix ablation config path resolution
This commit is contained in:
@@ -93,6 +93,32 @@ def write_json(path: Path, obj: Dict) -> None:
|
||||
json.dump(obj, f, indent=2)
|
||||
|
||||
|
||||
def absolutize_config_paths(cfg: Dict, base_config_path: Path) -> Dict:
|
||||
"""Freeze path-like config values so generated ablation configs remain runnable."""
|
||||
cfg = dict(cfg)
|
||||
path_keys = ["data_path", "split_path", "stats_path", "vocab_path", "out_dir"]
|
||||
glob_keys = ["data_glob"]
|
||||
|
||||
for key in path_keys:
|
||||
value = cfg.get(key)
|
||||
if not value:
|
||||
continue
|
||||
path = Path(value)
|
||||
if not path.is_absolute():
|
||||
cfg[key] = str((base_config_path.parent / path).resolve())
|
||||
|
||||
for key in glob_keys:
|
||||
value = cfg.get(key)
|
||||
if not value:
|
||||
continue
|
||||
path = Path(value)
|
||||
if not path.is_absolute():
|
||||
resolved = base_config_path.parent / path
|
||||
cfg[key] = str(resolved)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
def selected_variants(arg: str) -> List[str]:
|
||||
if not arg:
|
||||
return list(DEFAULT_ABLATIONS.keys())
|
||||
@@ -153,7 +179,7 @@ def main():
|
||||
|
||||
generated_configs: Dict[str, Path] = {}
|
||||
for variant in variants:
|
||||
cfg = dict(base_cfg)
|
||||
cfg = absolutize_config_paths(base_cfg, config_path)
|
||||
cfg.update(DEFAULT_ABLATIONS[variant])
|
||||
cfg_path = out_root / "configs" / f"{variant}.json"
|
||||
write_json(cfg_path, cfg)
|
||||
|
||||
Reference in New Issue
Block a user