Files
mask-ddpm/example/filtered_metrics.py

66 lines
2.1 KiB
Python

#!/usr/bin/env python3
"""Compute filtered KS/JSD by excluding hard-to-learn features."""
import argparse
import json
import math
from pathlib import Path
def parse_args():
    """Parse command-line options for the filtered-metrics script.

    Default input/output paths live in a ``results`` directory next to this
    script file.

    Returns:
        argparse.Namespace with ``eval``, ``min_std``, ``ks_threshold`` and
        ``out`` attributes.
    """
    results_dir = Path(__file__).resolve().parent / "results"
    cli = argparse.ArgumentParser(description="Filtered metrics from eval.json.")
    cli.add_argument(
        "--eval",
        default=str(results_dir / "eval.json"),
    )
    cli.add_argument(
        "--min-std",
        type=float,
        default=1e-3,
        help="threshold for std collapse",
    )
    cli.add_argument(
        "--ks-threshold",
        type=float,
        default=0.95,
        help="auto-exclude if KS >= threshold",
    )
    cli.add_argument(
        "--out",
        default=str(results_dir / "filtered_metrics.json"),
    )
    return cli.parse_args()
def _is_missing(value):
    """Return True when *value* cannot contribute to an average (None or NaN)."""
    return value is None or (isinstance(value, float) and math.isnan(value))


def _partition_features(cont_ks, cont_stats, min_std, ks_threshold):
    """Split features into kept and dropped sets using the std/KS rules.

    A feature is dropped when its recorded std is <= ``min_std`` (collapsed)
    or its KS is >= ``ks_threshold`` (too hard to learn).

    Args:
        cont_ks: mapping of feature name -> KS value (may be None/NaN).
        cont_stats: mapping of feature name -> stats dict; only ``"std"`` is read.
        min_std: std collapse threshold.
        ks_threshold: KS auto-exclusion threshold.

    Returns:
        ``(kept, dropped, ks_vals)``: kept feature names in input order,
        dropped records ``{"feature", "ks", "std"}``, and the KS values of
        kept features that are usable for averaging (None/NaN excluded).
    """
    kept = []
    dropped = []
    ks_vals = []
    for feat, ks in cont_ks.items():
        # A null stats entry is treated like an absent one.
        std = (cont_stats.get(feat) or {}).get("std")
        collapsed = std is not None and std <= min_std
        too_hard = ks is not None and ks >= ks_threshold
        if collapsed or too_hard:
            dropped.append({"feature": feat, "ks": ks, "std": std})
            continue
        kept.append(feat)
        # Features without a usable KS are still kept, but must not poison
        # the average (None crashed sum(); NaN made the mean NaN).
        if not _is_missing(ks):
            ks_vals.append(ks)
    return kept, dropped, ks_vals


def main(args=None):
    """Re-average KS over continuous features after excluding hard-to-learn ones.

    Reads ``continuous_ks``, ``continuous_summary`` and ``avg_ks`` from the
    eval JSON, writes a summary JSON to ``args.out`` (creating its parent
    directory if needed), and prints a short report.

    Args:
        args: optional namespace with ``eval``, ``min_std``, ``ks_threshold``
            and ``out`` attributes; parsed from the command line when omitted.

    Raises:
        SystemExit: if the eval file does not exist.
    """
    if args is None:
        args = parse_args()
    eval_path = Path(args.eval)
    if not eval_path.exists():
        raise SystemExit(f"missing eval.json: {eval_path}")
    data = json.loads(eval_path.read_text(encoding="utf-8"))
    # `or {}` also tolerates keys that are present but null.
    cont_ks = data.get("continuous_ks") or {}
    cont_stats = data.get("continuous_summary") or {}

    kept, dropped, ks_vals = _partition_features(
        cont_ks, cont_stats, args.min_std, args.ks_threshold
    )
    filtered_avg_ks = sum(ks_vals) / len(ks_vals) if ks_vals else None

    out = {
        "filtered_avg_ks": filtered_avg_ks,
        "kept_features": kept,
        "dropped_features": dropped,
        "rules": {
            "min_std": args.min_std,
            "ks_threshold": args.ks_threshold,
        },
        "original_avg_ks": data.get("avg_ks"),
    }
    out_path = Path(args.out)
    # --out may point into a directory that does not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(out, indent=2), encoding="utf-8")
    print("filtered_avg_ks", filtered_avg_ks)
    print("dropped", len(dropped))
    print("wrote", args.out)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()