Files
mask-ddpm/example/ranked_ks.py

65 lines
2.2 KiB
Python

#!/usr/bin/env python3
"""Rank per-feature KS and show cumulative effect on avg_ks."""
import argparse
import csv
import json
from pathlib import Path
def parse_args(argv=None):
    """Parse the command-line options of the KS ranking script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what
            the script did before this parameter existed.

    Returns:
        argparse.Namespace with three attributes:
        ``eval`` (path of the input JSON, as a string),
        ``out`` (path of the CSV to write, as a string) and
        ``top`` (int, number of ranked rows to print).
    """
    parser = argparse.ArgumentParser(description="Rank KS from eval.json.")
    # Defaults live in a "results" directory next to this file, so they do
    # not depend on the current working directory.
    base_dir = Path(__file__).resolve().parent
    parser.add_argument(
        "--eval",
        default=str(base_dir / "results" / "eval.json"),
        help="input eval.json to read",
    )
    parser.add_argument(
        "--out",
        default=str(base_dir / "results" / "ranked_ks.csv"),
        help="CSV file to write the ranking to",
    )
    parser.add_argument(
        "--top",
        type=int,
        default=20,
        help="number of top-ranked features to print (default: 20)",
    )
    return parser.parse_args(argv)
def _rank_features(cont_ks):
    """Return one row dict per feature, ordered by descending KS value.

    ``cont_ks`` maps feature name -> KS value (values are assumed to be
    numeric and mutually comparable; this is not validated here).
    """
    ranked = sorted(cont_ks.items(), key=lambda kv: kv[1], reverse=True)
    count = len(ranked)
    total = sum(ks for _, ks in ranked)
    rows = []
    removed = 0.0  # running sum of the KS values ranked so far
    for rank, (feat, ks) in enumerate(ranked, 1):
        removed += ks
        remaining = count - rank
        rows.append(
            {
                "rank": rank,
                "feature": feat,
                "ks": ks,
                # Share of the overall mean KS that this feature accounts for.
                "contribution_to_avg": ks / count,
                # Mean KS of the features left after dropping the top `rank`
                # ones; undefined (None) once nothing is left.
                "avg_ks_if_remove_top_n": (
                    (total - removed) / remaining if remaining > 0 else None
                ),
            }
        )
    return rows


def _format_ks(value, missing):
    """Format a KS value with six decimals, or return ``missing`` for None."""
    return missing if value is None else f"{value:.6f}"


def _write_csv(rows, out_path):
    """Write the ranked rows to ``out_path``, creating parent directories."""
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # newline="" hands line-ending control to the csv module; the writer
    # quotes feature names containing commas, quotes or newlines, which the
    # previous hand-built f-string rows did not.
    with out_path.open("w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(
            ["rank", "feature", "ks", "contribution_to_avg", "avg_ks_if_remove_top_n"]
        )
        for r in rows:
            writer.writerow(
                [
                    r["rank"],
                    r["feature"],
                    _format_ks(r["ks"], ""),
                    _format_ks(r["contribution_to_avg"], ""),
                    _format_ks(r["avg_ks_if_remove_top_n"], ""),
                ]
            )


def main():
    """Rank the per-feature KS values from eval.json and report them.

    Reads the ``continuous_ks`` mapping from the JSON file given by
    ``--eval``, writes the full ranking to the CSV given by ``--out`` and
    prints the first ``--top`` rows to stdout.

    Raises:
        SystemExit: if the eval file does not exist, or if ``continuous_ks``
            is missing, null or empty.
    """
    args = parse_args()
    eval_path = Path(args.eval)
    if not eval_path.exists():
        raise SystemExit(f"missing eval.json: {eval_path}")
    data = json.loads(eval_path.read_text(encoding="utf-8"))
    # `or {}` also covers an explicit JSON null, which would otherwise fail
    # on `.items()` instead of producing the clean exit below.
    cont_ks = data.get("continuous_ks") or {}
    if not cont_ks:
        raise SystemExit("continuous_ks empty")

    rows = _rank_features(cont_ks)

    out_path = Path(args.out)
    _write_csv(rows, out_path)
    print(f"wrote {out_path}")

    print("top features:")
    # Clamp at 0: a negative --top would otherwise slice from the end and
    # print "all but the last k" rows.
    for r in rows[: max(args.top, 0)]:
        print(r["rank"], r["feature"], r["ks"], _format_ks(r["avg_ks_if_remove_top_n"], "NA"))
# Run the ranking only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()