66 lines
2.1 KiB
Python
66 lines
2.1 KiB
Python
#!/usr/bin/env python3
|
|
"""Compute a filtered average KS score by excluding hard-to-learn continuous features."""
|
|
|
|
import argparse
|
|
import json
|
|
from pathlib import Path
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the filtered-metrics report.

    Default input and output paths live in a ``results`` directory next to
    this script, so they do not depend on the caller's working directory.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which matches the original behaviour.

    Returns:
        argparse.Namespace with attributes ``eval`` (str path), ``min_std``
        (float), ``ks_threshold`` (float) and ``out`` (str path).
    """
    parser = argparse.ArgumentParser(description="Filtered metrics from eval.json.")
    # Resolve defaults relative to this file rather than the CWD.
    results_dir = Path(__file__).resolve().parent / "results"
    parser.add_argument("--eval", default=str(results_dir / "eval.json"))
    parser.add_argument("--min-std", type=float, default=1e-3, help="threshold for std collapse")
    parser.add_argument("--ks-threshold", type=float, default=0.95, help="auto-exclude if KS >= threshold")
    parser.add_argument("--out", default=str(results_dir / "filtered_metrics.json"))
    return parser.parse_args(argv)
|
|
|
|
|
|
def filter_features(cont_ks, cont_stats, min_std, ks_threshold):
    """Split continuous features into kept and dropped groups.

    A feature is dropped when its summary ``std`` is present and
    ``<= min_std`` (collapsed distribution), or when its KS value is present
    and ``>= ks_threshold``. Otherwise it is kept.

    Args:
        cont_ks: Mapping of feature name -> KS value (the value may be None).
        cont_stats: Mapping of feature name -> summary dict. Only the
            ``"std"`` key is read; a missing or null entry means "no std".
        min_std: Std values at or below this are treated as collapsed.
        ks_threshold: KS values at or above this are auto-excluded.

    Returns:
        Tuple ``(kept, dropped, ks_vals)``:
        ``kept`` is the list of kept feature names in input order;
        ``dropped`` is a list of ``{"feature", "ks", "std"}`` dicts;
        ``ks_vals`` is the list of non-None KS values of the kept features.
    """
    kept = []
    dropped = []
    ks_vals = []
    for feat, ks in cont_ks.items():
        # ``or {}`` also covers a summary entry that is null in the JSON.
        std = (cont_stats.get(feat) or {}).get("std")
        collapsed = std is not None and std <= min_std
        too_hard = ks is not None and ks >= ks_threshold
        if collapsed or too_hard:
            dropped.append({"feature": feat, "ks": ks, "std": std})
            continue
        kept.append(feat)
        # A kept feature without a KS value cannot contribute to the mean;
        # appending None here would make sum() raise TypeError.
        if ks is not None:
            ks_vals.append(ks)
    return kept, dropped, ks_vals


def main():
    """Filter hard-to-learn features out of eval.json and write a summary.

    Loads the report named by ``--eval`` and filters its ``continuous_ks``
    entries with :func:`filter_features`, using ``continuous_summary`` for
    std values. It then writes a JSON summary to ``--out``. The summary
    contains the filtered average KS, the kept and dropped features, the
    thresholds used, and the original ``avg_ks``. The filtered average is
    ``None`` when no kept feature has a KS value.

    Raises:
        SystemExit: if the eval file is missing or is not a JSON object.
    """
    args = parse_args()
    eval_path = Path(args.eval)
    if not eval_path.exists():
        raise SystemExit(f"missing eval.json: {eval_path}")
    data = json.loads(eval_path.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        raise SystemExit(f"unexpected eval.json format (expected a JSON object): {eval_path}")

    # ``or {}`` guards against keys that are present but null.
    cont_ks = data.get("continuous_ks") or {}
    cont_stats = data.get("continuous_summary") or {}

    kept, dropped, ks_vals = filter_features(
        cont_ks, cont_stats, args.min_std, args.ks_threshold
    )
    filtered_avg_ks = sum(ks_vals) / len(ks_vals) if ks_vals else None

    out = {
        "filtered_avg_ks": filtered_avg_ks,
        "kept_features": kept,
        "dropped_features": dropped,
        "rules": {
            "min_std": args.min_std,
            "ks_threshold": args.ks_threshold,
        },
        "original_avg_ks": data.get("avg_ks"),
    }

    out_path = Path(args.out)
    # A custom --out may point into a directory that does not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(out, indent=2), encoding="utf-8")
    print("filtered_avg_ks", filtered_avg_ks)
    print("dropped", len(dropped))
    print("wrote", args.out)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point; main() returns None, so the exit status is 0.
    raise SystemExit(main())
|