Add filtered KS diagnostics and feature-type plan
This commit is contained in:
65
example/filtered_metrics.py
Normal file
65
example/filtered_metrics.py
Normal file
@@ -0,0 +1,65 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Compute filtered KS/JSD by excluding hard-to-learn features."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Filtered metrics from eval.json.")
|
||||
base_dir = Path(__file__).resolve().parent
|
||||
parser.add_argument("--eval", default=str(base_dir / "results" / "eval.json"))
|
||||
parser.add_argument("--min-std", type=float, default=1e-3, help="threshold for std collapse")
|
||||
parser.add_argument("--ks-threshold", type=float, default=0.95, help="auto-exclude if KS >= threshold")
|
||||
parser.add_argument("--out", default=str(base_dir / "results" / "filtered_metrics.json"))
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
eval_path = Path(args.eval)
|
||||
if not eval_path.exists():
|
||||
raise SystemExit(f"missing eval.json: {eval_path}")
|
||||
data = json.loads(eval_path.read_text(encoding="utf-8"))
|
||||
|
||||
cont_ks = data.get("continuous_ks", {})
|
||||
cont_stats = data.get("continuous_summary", {})
|
||||
|
||||
dropped = []
|
||||
kept = []
|
||||
ks_vals = []
|
||||
for feat, ks in cont_ks.items():
|
||||
std = None
|
||||
if feat in cont_stats:
|
||||
std = cont_stats[feat].get("std", None)
|
||||
drop = False
|
||||
if std is not None and std <= args.min_std:
|
||||
drop = True
|
||||
if ks is not None and ks >= args.ks_threshold:
|
||||
drop = True
|
||||
if drop:
|
||||
dropped.append({"feature": feat, "ks": ks, "std": std})
|
||||
else:
|
||||
kept.append(feat)
|
||||
ks_vals.append(ks)
|
||||
|
||||
filtered_avg_ks = sum(ks_vals) / len(ks_vals) if ks_vals else None
|
||||
out = {
|
||||
"filtered_avg_ks": filtered_avg_ks,
|
||||
"kept_features": kept,
|
||||
"dropped_features": dropped,
|
||||
"rules": {
|
||||
"min_std": args.min_std,
|
||||
"ks_threshold": args.ks_threshold,
|
||||
},
|
||||
"original_avg_ks": data.get("avg_ks"),
|
||||
}
|
||||
Path(args.out).write_text(json.dumps(out, indent=2), encoding="utf-8")
|
||||
print("filtered_avg_ks", filtered_avg_ks)
|
||||
print("dropped", len(dropped))
|
||||
print("wrote", args.out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user