Fix ablation config path resolution
This commit is contained in:
@@ -93,6 +93,32 @@ def write_json(path: Path, obj: Dict) -> None:
|
|||||||
json.dump(obj, f, indent=2)
|
json.dump(obj, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def absolutize_config_paths(cfg: Dict, base_config_path: Path) -> Dict:
|
||||||
|
"""Freeze path-like config values so generated ablation configs remain runnable."""
|
||||||
|
cfg = dict(cfg)
|
||||||
|
path_keys = ["data_path", "split_path", "stats_path", "vocab_path", "out_dir"]
|
||||||
|
glob_keys = ["data_glob"]
|
||||||
|
|
||||||
|
for key in path_keys:
|
||||||
|
value = cfg.get(key)
|
||||||
|
if not value:
|
||||||
|
continue
|
||||||
|
path = Path(value)
|
||||||
|
if not path.is_absolute():
|
||||||
|
cfg[key] = str((base_config_path.parent / path).resolve())
|
||||||
|
|
||||||
|
for key in glob_keys:
|
||||||
|
value = cfg.get(key)
|
||||||
|
if not value:
|
||||||
|
continue
|
||||||
|
path = Path(value)
|
||||||
|
if not path.is_absolute():
|
||||||
|
resolved = base_config_path.parent / path
|
||||||
|
cfg[key] = str(resolved)
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
def selected_variants(arg: str) -> List[str]:
|
def selected_variants(arg: str) -> List[str]:
|
||||||
if not arg:
|
if not arg:
|
||||||
return list(DEFAULT_ABLATIONS.keys())
|
return list(DEFAULT_ABLATIONS.keys())
|
||||||
@@ -153,7 +179,7 @@ def main():
|
|||||||
|
|
||||||
generated_configs: Dict[str, Path] = {}
|
generated_configs: Dict[str, Path] = {}
|
||||||
for variant in variants:
|
for variant in variants:
|
||||||
cfg = dict(base_cfg)
|
cfg = absolutize_config_paths(base_cfg, config_path)
|
||||||
cfg.update(DEFAULT_ABLATIONS[variant])
|
cfg.update(DEFAULT_ABLATIONS[variant])
|
||||||
cfg_path = out_root / "configs" / f"{variant}.json"
|
cfg_path = out_root / "configs" / f"{variant}.json"
|
||||||
write_json(cfg_path, cfg)
|
write_json(cfg_path, cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user