Files
mask-ddpm/example/run_ablations.py
2026-03-26 22:58:23 +08:00

251 lines
8.6 KiB
Python

#!/usr/bin/env python3
"""Run a default ablation suite and summarize results."""
from __future__ import annotations
import argparse
import csv
import json
import subprocess
import sys
from pathlib import Path
from typing import Dict, List
from platform_utils import safe_path, is_windows
# Named ablation variants. Each value is the set of config-key overrides that
# main() layers on top of the base config for that variant; "full" overrides
# nothing and therefore reproduces the base config unchanged.
DEFAULT_ABLATIONS = {
    "full": {},
    "no_temporal": {"use_temporal_stage1": False},
    "no_quantile": {
        "use_quantile_transform": False,
        "cont_post_calibrate": False,
        "full_stats": False,
    },
    "no_post_calibration": {"cont_post_calibrate": False},
    "no_file_condition": {"use_condition": False},
    # Empties every typeN_features list (N = 1..6); each key gets its own list.
    "no_type_routing": {f"type{index}_features": [] for index in range(1, 7)},
    "no_snr_weight": {"snr_weighted_loss": False},
    "no_quantile_loss": {"quantile_loss_weight": 0.0},
    "no_residual_stat": {"residual_stat_weight": 0.0},
    "eps_target": {"cont_target": "eps", "cont_clamp_x0": 0.0},
}
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Defaults for ``--config`` and ``--out-root`` are anchored at the directory
    containing this script. Every ``--skip-*`` option is a boolean switch that
    defaults to ``False``.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    script_dir = Path(__file__).resolve().parent
    parser = argparse.ArgumentParser(description="Run ablation experiments.")

    # (flag, default, help) for the value-taking options, in registration order.
    valued_options = (
        ("--config", str(script_dir / "config.json"), None),
        ("--device", "auto", None),
        ("--variants", "", "comma-separated variant names; empty uses defaults"),
        ("--seeds", "", "comma-separated seeds; empty uses config seed"),
        ("--out-root", str(script_dir / "results" / "ablations"), None),
    )
    for flag, default, help_text in valued_options:
        parser.add_argument(flag, default=default, help=help_text)

    # Pipeline stages that can be switched off individually.
    skippable_stages = (
        "prepare",
        "train",
        "export",
        "eval",
        "comprehensive-eval",
        "postprocess",
        "post-eval",
        "diagnostics",
    )
    for stage in skippable_stages:
        parser.add_argument(f"--skip-{stage}", action="store_true")

    return parser.parse_args()
def run(cmd: List[str]) -> None:
    """Echo a command line, then execute it and fail on a non-zero exit.

    Args:
        cmd: Argument vector. Each element is passed through ``safe_path``
            (imported from ``platform_utils``) before execution.

    Raises:
        subprocess.CalledProcessError: If the child process exits with a
            non-zero status (``check=True``).
    """
    # The echo shows the arguments as given, before safe_path() rewrites them.
    print("running:", " ".join(cmd))
    argv = [safe_path(arg) for arg in cmd]
    # shell=False is subprocess.run's default, so Windows and non-Windows need
    # the exact same call; no per-platform branch is required. It is spelled
    # out to make explicit that the list is executed without a shell.
    subprocess.run(argv, check=True, shell=False)
def load_json(path: Path) -> Dict:
    """Read the UTF-8 encoded file at ``path`` and return its parsed JSON content."""
    text = path.read_text(encoding="utf-8")
    return json.loads(text)
def write_json(path: Path, obj: Dict) -> None:
    """Serialize ``obj`` as 2-space-indented JSON into the UTF-8 file ``path``.

    Missing parent directories are created first; an existing file is
    overwritten.
    """
    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(obj, handle, indent=2)
def absolutize_config_paths(cfg: Dict, base_config_path: Path) -> Dict:
    """Return a shallow copy of ``cfg`` with relative path entries anchored.

    Relative values are joined onto the directory that contains
    ``base_config_path`` so that a config written elsewhere still points at
    the same locations. Missing/empty entries and already-absolute values are
    left untouched; the input mapping itself is not modified.

    Args:
        cfg: Base configuration mapping.
        base_config_path: Path of the config file that ``cfg`` came from.

    Returns:
        The copied mapping with the path-like entries rewritten as strings.
    """
    frozen = dict(cfg)
    anchor = base_config_path.parent

    # (key, resolve?) in processing order. Plain paths are fully resolved; the
    # glob pattern is only joined onto the anchor, never passed to resolve().
    rewrite_plan = [
        ("data_path", True),
        ("split_path", True),
        ("stats_path", True),
        ("vocab_path", True),
        ("out_dir", True),
        ("data_glob", False),
    ]
    for key, should_resolve in rewrite_plan:
        raw = frozen.get(key)
        if not raw:
            continue
        candidate = Path(raw)
        if candidate.is_absolute():
            continue
        anchored = anchor / candidate
        frozen[key] = str(anchored.resolve() if should_resolve else anchored)
    return frozen
def selected_variants(arg: str) -> List[str]:
    """Turn the ``--variants`` argument into a list of ablation names.

    Args:
        arg: Comma-separated variant names. Surrounding whitespace and empty
            items are ignored.

    Returns:
        The requested names in first-seen order with duplicates removed, or
        every key of ``DEFAULT_ABLATIONS`` when no name is given.

    Raises:
        SystemExit: If any requested name is not a key of
            ``DEFAULT_ABLATIONS``.
    """
    stripped = (name.strip() for name in (arg or "").split(","))
    # dict.fromkeys de-duplicates while keeping first-seen order. main() writes
    # one config per variant name and derives run directories from that name,
    # so a repeated name would only redo the same run and duplicate its rows.
    names = list(dict.fromkeys(name for name in stripped if name))
    if not names:
        # Covers "" as well as separator-only input such as "," or " , ".
        return list(DEFAULT_ABLATIONS)
    unknown = [name for name in names if name not in DEFAULT_ABLATIONS]
    if unknown:
        raise SystemExit(f"unknown ablation names: {', '.join(unknown)}")
    return names
def parse_seeds(arg: str, cfg: Dict) -> List[int]:
    """Turn the ``--seeds`` argument into a list of integer seeds.

    Args:
        arg: Comma-separated seeds. Surrounding whitespace and empty items are
            ignored.
        cfg: Loaded base config; its ``"seed"`` entry is the fallback.

    Returns:
        The requested seeds in first-seen order with duplicates removed. When
        no seed is given, a one-element list holding the config's ``"seed"``
        (1337 if that entry is missing or ``None``).

    Raises:
        ValueError: If an item (or the config seed) is not a valid integer.
    """
    tokens = (item.strip() for item in (arg or "").split(","))
    # De-duplicate by integer value, keeping first-seen order: main() derives
    # the run directory from the seed, so a repeated seed would only redo the
    # same run and duplicate its summary row.
    seeds = list(dict.fromkeys(int(token) for token in tokens if token))
    if seeds:
        return seeds
    # Nothing usable on the command line ("" or separator-only input such as
    # ","): fall back to the config instead of returning an empty seed list.
    configured = cfg.get("seed")
    return [1337 if configured is None else int(configured)]
def collect_metrics(run_dir: Path) -> Dict[str, float]:
    """Gather summary metrics from the JSON reports found in ``run_dir``.

    Three optional files are consulted, in this order: ``eval.json``,
    ``comprehensive_eval.json`` and ``eval_post.json``. A file that does not
    exist contributes no keys. A metric missing from an existing file is
    recorded as ``None``.

    Args:
        run_dir: Directory of a single run.

    Returns:
        Mapping from metric name to the value read from the reports.
    """

    def lookup(document, path):
        # Descend through the section names, treating a missing section as
        # empty, and read the final key without a default.
        node = document
        for section in path[:-1]:
            node = node.get(section, {})
        return node.get(path[-1])

    headline_keys = ("avg_ks", "avg_jsd", "avg_lag1_diff")
    # Output name -> key path inside comprehensive_eval.json.
    comprehensive_paths = (
        ("continuous_mmd_rbf", ("two_sample", "continuous_mmd_rbf")),
        ("discriminative_accuracy", ("two_sample", "discriminative_accuracy")),
        ("corr_mean_abs_diff", ("coupling", "corr_mean_abs_diff")),
        ("avg_psd_l1", ("frequency", "avg_psd_l1")),
        ("memorization_ratio", ("diversity_privacy", "memorization_ratio")),
        ("predictive_rmse_real", ("predictive_consistency", "real_only", "rmse")),
        ("predictive_rmse_synth", ("predictive_consistency", "synthetic_only", "rmse")),
        ("utility_auprc_real", ("anomaly_utility", "real_only", "auprc")),
        ("utility_auprc_synth", ("anomaly_utility", "synthetic_only", "auprc")),
        ("utility_auprc_aug", ("anomaly_utility", "real_plus_synthetic", "auprc")),
    )

    metrics: Dict[str, float] = {}

    eval_file = run_dir / "eval.json"
    if eval_file.exists():
        report = load_json(eval_file)
        for key in headline_keys:
            metrics[key] = report.get(key)

    comprehensive_file = run_dir / "comprehensive_eval.json"
    if comprehensive_file.exists():
        report = load_json(comprehensive_file)
        for name, path in comprehensive_paths:
            metrics[name] = lookup(report, path)

    post_file = run_dir / "eval_post.json"
    if post_file.exists():
        report = load_json(post_file)
        # Same headline metrics, stored under a "post_" prefix.
        for key in headline_keys:
            metrics[f"post_{key}"] = report.get(key)

    return metrics
def main():
    """Generate per-variant configs, run each variant, and write the summary.

    Steps:
      1. Parse arguments, create the output root, and load the base config
         (a relative ``--config`` is taken relative to this script's folder).
      2. Write ``<out-root>/configs/<variant>.json`` for every selected
         variant: the path-absolutized base config plus that variant's
         overrides from ``DEFAULT_ABLATIONS``.
      3. Invoke ``run_all.py`` once per variant with all seeds, forwarding
         the ``--skip-*`` switches that were set.
      4. Collect metrics for every (variant, seed) run directory and write
         ``ablation_summary.csv`` and ``ablation_summary.json``.
    """
    args = parse_args()
    script_dir = Path(__file__).resolve().parent

    out_root = Path(args.out_root)
    out_root.mkdir(parents=True, exist_ok=True)

    config_path = Path(args.config)
    if not config_path.is_absolute():
        config_path = (script_dir / config_path).resolve()
    base_cfg = load_json(config_path)

    variants = selected_variants(args.variants)
    seeds = parse_seeds(args.seeds, base_cfg)

    # All variant configs are written up front, before any run starts.
    generated_configs: Dict[str, Path] = {}
    for name in variants:
        variant_cfg = {
            **absolutize_config_paths(base_cfg, config_path),
            **DEFAULT_ABLATIONS[name],
        }
        target = out_root / "configs" / f"{name}.json"
        write_json(target, variant_cfg)
        generated_configs[name] = target

    history_path = out_root / "benchmark_history.csv"
    summary_path = out_root / "benchmark_summary.csv"
    runs_root = out_root / "runs"

    # Arguments that are identical for every variant's run_all.py invocation.
    shared_args = [
        "--device",
        args.device,
        "--runs-root",
        str(runs_root),
        "--benchmark-history",
        str(history_path),
        "--benchmark-summary",
        str(summary_path),
        "--seeds",
        ",".join(map(str, seeds)),
    ]
    # (namespace attribute, flag to forward); order fixes the flag order on
    # the child command line.
    skip_switches = (
        ("skip_train", "--skip-train"),
        ("skip_prepare", "--skip-prepare"),
        ("skip_export", "--skip-export"),
        ("skip_eval", "--skip-eval"),
        ("skip_comprehensive_eval", "--skip-comprehensive-eval"),
        ("skip_postprocess", "--skip-postprocess"),
        ("skip_post_eval", "--skip-post-eval"),
        ("skip_diagnostics", "--skip-diagnostics"),
    )
    forwarded_skips = [flag for attr, flag in skip_switches if getattr(args, attr)]

    rows: List[Dict[str, object]] = []
    for name in variants:
        cfg_path = generated_configs[name]
        run(
            [
                sys.executable,
                str(script_dir / "run_all.py"),
                "--config",
                str(cfg_path),
                *shared_args,
                *forwarded_skips,
            ]
        )
        for seed in seeds:
            # NOTE(review): assumes run_all.py names run directories
            # "<config stem>__seed<seed>" under --runs-root; confirm there.
            run_dir = runs_root / f"{cfg_path.stem}__seed{seed}"
            rows.append(
                {
                    "variant": name,
                    "seed": seed,
                    "run_dir": str(run_dir),
                    **collect_metrics(run_dir),
                }
            )

    # Column set is the union of keys over all rows, in alphabetical order.
    fieldnames = sorted(set().union(*rows))
    csv_path = out_root / "ablation_summary.csv"
    with csv_path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)

    json_path = out_root / "ablation_summary.json"
    write_json(json_path, {"variants": variants, "seeds": seeds, "rows": rows})

    print("wrote", csv_path)
    print("wrote", json_path)


if __name__ == "__main__":
    main()