#!/usr/bin/env python3
"""Compute a filtered average KS by excluding hard-to-learn features.

Reads ``eval.json`` (keys ``continuous_ks``, ``continuous_summary`` and
``avg_ks``), drops continuous features whose std has collapsed or whose KS is
at/above a threshold, and writes the average KS over the remaining features
to a JSON report. Only KS is handled here; JSD values in eval.json, if any,
are not read.
"""
import argparse
import json
import math
from pathlib import Path


def parse_args(argv=None):
    """Parse command-line options.

    Args:
        argv: Optional list of arguments; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``eval``, ``min_std``, ``ks_threshold`` and
        ``out`` attributes. Default paths live under ``results/`` next to
        this script.
    """
    parser = argparse.ArgumentParser(description="Filtered metrics from eval.json.")
    base_dir = Path(__file__).resolve().parent
    parser.add_argument("--eval", default=str(base_dir / "results" / "eval.json"))
    parser.add_argument("--min-std", type=float, default=1e-3, help="threshold for std collapse")
    parser.add_argument("--ks-threshold", type=float, default=0.95, help="auto-exclude if KS >= threshold")
    parser.add_argument("--out", default=str(base_dir / "results" / "filtered_metrics.json"))
    return parser.parse_args(argv)


def _finite_or_none(value):
    """Return *value* unchanged if it is a finite real number, else ``None``."""
    # bool is a subclass of int; a True/False "KS" or "std" is not a measurement.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return None
    return value if math.isfinite(value) else None


def _filter_features(cont_ks, cont_stats, min_std, ks_threshold):
    """Split features into kept/dropped and average KS over the kept ones.

    A feature is dropped when its std is known and ``<= min_std``, when its KS
    is ``>= ks_threshold``, or when its KS is missing/non-finite (it cannot be
    averaged).

    Args:
        cont_ks: Mapping of feature name -> KS value.
        cont_stats: Mapping of feature name -> stats dict (``"std"`` is read).
        min_std: Std at or below which a feature counts as collapsed.
        ks_threshold: KS at or above which a feature is auto-excluded.

    Returns:
        ``(kept, dropped, filtered_avg_ks)``. ``kept`` is a list of feature
        names. ``dropped`` is a list of dicts with keys ``feature``, ``ks``,
        ``std`` and ``reasons``. ``filtered_avg_ks`` is ``None`` when no
        feature is kept.
    """
    kept = []
    dropped = []
    ks_vals = []
    for feat, raw_ks in cont_ks.items():
        stats = cont_stats.get(feat)
        std = _finite_or_none(stats.get("std")) if isinstance(stats, dict) else None
        ks = _finite_or_none(raw_ks)

        reasons = []
        if std is not None and std <= min_std:
            reasons.append("std_collapse")
        if ks is None:
            # A missing/NaN KS cannot be averaged. Keeping it used to crash
            # sum() (None) or silently turn the average into NaN.
            reasons.append("ks_unavailable")
        elif ks >= ks_threshold:
            reasons.append("high_ks")

        if reasons:
            dropped.append({"feature": feat, "ks": ks, "std": std, "reasons": reasons})
        else:
            kept.append(feat)
            ks_vals.append(ks)

    filtered_avg_ks = math.fsum(ks_vals) / len(ks_vals) if ks_vals else None
    return kept, dropped, filtered_avg_ks


def main(argv=None):
    """Filter features from ``--eval`` and write the report to ``--out``.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`.

    Raises:
        SystemExit: if the eval file is missing, is not valid JSON, or its
            top level is not a JSON object.
    """
    args = parse_args(argv)
    eval_path = Path(args.eval)
    if not eval_path.exists():
        raise SystemExit(f"missing eval.json: {eval_path}")
    try:
        data = json.loads(eval_path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as err:
        raise SystemExit(f"invalid eval.json: {eval_path}: {err}") from err
    if not isinstance(data, dict):
        raise SystemExit(f"invalid eval.json (expected an object): {eval_path}")

    # `or {}` also covers an explicit null in the file.
    cont_ks = data.get("continuous_ks") or {}
    cont_stats = data.get("continuous_summary") or {}

    kept, dropped, filtered_avg_ks = _filter_features(
        cont_ks, cont_stats, args.min_std, args.ks_threshold
    )

    out = {
        "filtered_avg_ks": filtered_avg_ks,
        "kept_features": kept,
        "dropped_features": dropped,
        "rules": {
            "min_std": args.min_std,
            "ks_threshold": args.ks_threshold,
        },
        "original_avg_ks": data.get("avg_ks"),
    }
    out_path = Path(args.out)
    # The results/ directory may not exist yet on a fresh checkout.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(out, indent=2), encoding="utf-8")
    print("filtered_avg_ks", filtered_avg_ks)
    print("dropped", len(dropped))
    print("wrote", args.out)


if __name__ == "__main__":
    main()