#!/usr/bin/env python3 """Run a default ablation suite and summarize results.""" from __future__ import annotations import argparse import csv import json import subprocess import sys from pathlib import Path from typing import Dict, List from platform_utils import safe_path, is_windows DEFAULT_ABLATIONS = { "full": {}, "no_temporal": { "use_temporal_stage1": False, }, "no_quantile": { "use_quantile_transform": False, "cont_post_calibrate": False, "full_stats": False, }, "no_post_calibration": { "cont_post_calibrate": False, }, "no_file_condition": { "use_condition": False, }, "no_type_routing": { "type1_features": [], "type2_features": [], "type3_features": [], "type4_features": [], "type5_features": [], "type6_features": [], }, "no_snr_weight": { "snr_weighted_loss": False, }, "no_quantile_loss": { "quantile_loss_weight": 0.0, }, "no_residual_stat": { "residual_stat_weight": 0.0, }, "eps_target": { "cont_target": "eps", "cont_clamp_x0": 0.0, }, } def parse_args(): base_dir = Path(__file__).resolve().parent parser = argparse.ArgumentParser(description="Run ablation experiments.") parser.add_argument("--config", default=str(base_dir / "config.json")) parser.add_argument("--device", default="auto") parser.add_argument("--variants", default="", help="comma-separated variant names; empty uses defaults") parser.add_argument("--seeds", default="", help="comma-separated seeds; empty uses config seed") parser.add_argument("--out-root", default=str(base_dir / "results" / "ablations")) parser.add_argument("--skip-prepare", action="store_true") parser.add_argument("--skip-train", action="store_true") parser.add_argument("--skip-export", action="store_true") parser.add_argument("--skip-eval", action="store_true") parser.add_argument("--skip-comprehensive-eval", action="store_true") parser.add_argument("--skip-postprocess", action="store_true") parser.add_argument("--skip-post-eval", action="store_true") parser.add_argument("--skip-diagnostics", action="store_true") return parser.parse_args() def run(cmd: List[str]) -> None: print("running:", " ".join(cmd)) cmd = [safe_path(arg) for arg in cmd] if is_windows(): subprocess.run(cmd, check=True, shell=False) else: subprocess.run(cmd, check=True) def load_json(path: Path) -> Dict: with path.open("r", encoding="utf-8") as f: return json.load(f) def write_json(path: Path, obj: Dict) -> None: path.parent.mkdir(parents=True, exist_ok=True) with path.open("w", encoding="utf-8") as f: json.dump(obj, f, indent=2) def absolutize_config_paths(cfg: Dict, base_config_path: Path) -> Dict: """Freeze path-like config values so generated ablation configs remain runnable.""" cfg = dict(cfg) path_keys = ["data_path", "split_path", "stats_path", "vocab_path", "out_dir"] glob_keys = ["data_glob"] for key in path_keys: value = cfg.get(key) if not value: continue path = Path(value) if not path.is_absolute(): cfg[key] = str((base_config_path.parent / path).resolve()) for key in glob_keys: value = cfg.get(key) if not value: continue path = Path(value) if not path.is_absolute(): resolved = base_config_path.parent / path cfg[key] = str(resolved) return cfg def selected_variants(arg: str) -> List[str]: if not arg: return list(DEFAULT_ABLATIONS.keys()) names = [name.strip() for name in arg.split(",") if name.strip()] unknown = [name for name in names if name not in DEFAULT_ABLATIONS] if unknown: raise SystemExit(f"unknown ablation names: {', '.join(unknown)}") return names def parse_seeds(arg: str, cfg: Dict) -> List[int]: if not arg: return [int(cfg.get("seed", 1337))] return [int(item.strip()) for item in arg.split(",") if item.strip()] def collect_metrics(run_dir: Path) -> Dict[str, float]: out: Dict[str, float] = {} eval_path = run_dir / "eval.json" if eval_path.exists(): data = load_json(eval_path) out["avg_ks"] = data.get("avg_ks") out["avg_jsd"] = data.get("avg_jsd") out["avg_lag1_diff"] = data.get("avg_lag1_diff") comp_path = run_dir / "comprehensive_eval.json" if comp_path.exists(): data = load_json(comp_path) out["continuous_mmd_rbf"] = data.get("two_sample", {}).get("continuous_mmd_rbf") out["discriminative_accuracy"] = data.get("two_sample", {}).get("discriminative_accuracy") out["corr_mean_abs_diff"] = data.get("coupling", {}).get("corr_mean_abs_diff") out["avg_psd_l1"] = data.get("frequency", {}).get("avg_psd_l1") out["memorization_ratio"] = data.get("diversity_privacy", {}).get("memorization_ratio") out["predictive_rmse_real"] = data.get("predictive_consistency", {}).get("real_only", {}).get("rmse") out["predictive_rmse_synth"] = data.get("predictive_consistency", {}).get("synthetic_only", {}).get("rmse") out["utility_auprc_real"] = data.get("anomaly_utility", {}).get("real_only", {}).get("auprc") out["utility_auprc_synth"] = data.get("anomaly_utility", {}).get("synthetic_only", {}).get("auprc") out["utility_auprc_aug"] = data.get("anomaly_utility", {}).get("real_plus_synthetic", {}).get("auprc") post_eval_path = run_dir / "eval_post.json" if post_eval_path.exists(): post = load_json(post_eval_path) out["post_avg_ks"] = post.get("avg_ks") out["post_avg_jsd"] = post.get("avg_jsd") out["post_avg_lag1_diff"] = post.get("avg_lag1_diff") return out def main(): args = parse_args() base_dir = Path(__file__).resolve().parent out_root = Path(args.out_root) out_root.mkdir(parents=True, exist_ok=True) config_path = Path(args.config) if not config_path.is_absolute(): config_path = (base_dir / config_path).resolve() base_cfg = load_json(config_path) variants = selected_variants(args.variants) seeds = parse_seeds(args.seeds, base_cfg) generated_configs: Dict[str, Path] = {} for variant in variants: cfg = absolutize_config_paths(base_cfg, config_path) cfg.update(DEFAULT_ABLATIONS[variant]) cfg_path = out_root / "configs" / f"{variant}.json" write_json(cfg_path, cfg) generated_configs[variant] = cfg_path history_path = out_root / "benchmark_history.csv" summary_path = out_root / "benchmark_summary.csv" runs_root = out_root / "runs" rows: List[Dict[str, object]] = [] for variant in variants: cfg_path = generated_configs[variant] cmd = [ sys.executable, str(base_dir / "run_all.py"), "--config", str(cfg_path), "--device", args.device, "--runs-root", str(runs_root), "--benchmark-history", str(history_path), "--benchmark-summary", str(summary_path), "--seeds", ",".join(str(seed) for seed in seeds), ] if args.skip_train: cmd.append("--skip-train") if args.skip_prepare: cmd.append("--skip-prepare") if args.skip_export: cmd.append("--skip-export") if args.skip_eval: cmd.append("--skip-eval") if args.skip_comprehensive_eval: cmd.append("--skip-comprehensive-eval") if args.skip_postprocess: cmd.append("--skip-postprocess") if args.skip_post_eval: cmd.append("--skip-post-eval") if args.skip_diagnostics: cmd.append("--skip-diagnostics") run(cmd) for seed in seeds: run_dir = runs_root / f"{cfg_path.stem}__seed{seed}" row: Dict[str, object] = {"variant": variant, "seed": seed, "run_dir": str(run_dir)} row.update(collect_metrics(run_dir)) rows.append(row) fieldnames = sorted({key for row in rows for key in row.keys()}) csv_path = out_root / "ablation_summary.csv" with csv_path.open("w", encoding="utf-8", newline="") as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() for row in rows: writer.writerow(row) json_path = out_root / "ablation_summary.json" write_json(json_path, {"variants": variants, "seeds": seeds, "rows": rows}) print("wrote", csv_path) print("wrote", json_path) if __name__ == "__main__": main()