#!/usr/bin/env python3
"""Per-feature KS diagnostics and CDF visualization (no third-party deps).

Compares the marginal distribution of every continuous feature in a generated
CSV against reference (real) CSV data. Writes a per-feature KS table
(``ks_per_feature.csv``) and a JSON summary (``ks_summary.json``). Renders
ECDF comparison plots (``cdf_<feature>.svg``) for the top-k features by KS.
"""

import argparse
import csv
import gzip
import html
import json
import math
from glob import glob
from pathlib import Path


def parse_args():
    """Parse command-line options; path defaults are relative to this script."""
    parser = argparse.ArgumentParser(description="Per-feature KS diagnostics.")
    base_dir = Path(__file__).resolve().parent
    parser.add_argument("--generated", default=str(base_dir / "results" / "generated.csv"))
    parser.add_argument("--reference", default=str(base_dir / "config.json"))
    parser.add_argument("--out-dir", default=str(base_dir / "results"))
    parser.add_argument("--max-rows", type=int, default=200000, help="<=0 for full scan")
    parser.add_argument("--stride", type=int, default=1, help="row stride sampling")
    parser.add_argument("--top-k", type=int, default=8)
    return parser.parse_args()


def load_split(base_dir: Path):
    """Return the continuous columns listed in ``base_dir/feature_split.json``.

    The time column (the ``time_column`` key, default ``"time"``) is excluded.
    """
    with open(base_dir / "feature_split.json", "r", encoding="utf-8") as f:
        split = json.load(f)
    time_col = split.get("time_column", "time")
    return [c for c in split["continuous"] if c != time_col]


def resolve_reference_glob(base_dir: Path, ref_arg: str):
    """Resolve ``--reference`` to a glob pattern for the reference CSV files.

    If *ref_arg* is a ``.json`` config, its ``data_glob`` (or ``data_path``)
    entry is resolved relative to the config file's directory. Otherwise
    *ref_arg* is returned unchanged as the pattern. *base_dir* is currently
    unused and kept for interface compatibility.

    Raises:
        SystemExit: if the config has neither ``data_glob`` nor ``data_path``.
    """
    ref_path = Path(ref_arg)
    if ref_path.suffix == ".json":
        cfg = json.loads(ref_path.read_text(encoding="utf-8"))
        data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
        if not data_glob:
            raise SystemExit("reference config has no data_glob/data_path")
        ref_path = (ref_path.parent / data_glob).resolve()
    return str(ref_path)


def read_csv_values(path: Path, cols, max_rows=200000, stride=1, gz=True):
    """Collect finite float values per column from a (optionally gzipped) CSV.

    Args:
        path: CSV file with a header row.
        cols: column names to collect; a missing column yields an empty list.
        max_rows: stop after this many *sampled* rows; ``<= 0`` scans all rows.
        stride: keep every ``stride``-th data row (values ``<= 1`` keep all).
        gz: decompress with gzip when true, otherwise read as plain text.

    Returns:
        ``(values, row_count)``. ``values`` maps each column to its parsed
        finite floats, and ``row_count`` is the number of sampled rows. Empty,
        non-numeric, NaN and infinite cells are skipped individually.
    """
    values = {c: [] for c in cols}
    row_count = 0
    opener = gzip.open if gz else open
    # Explicit UTF-8 in both branches so decoding does not depend on the locale.
    with opener(path, "rt", newline="", encoding="utf-8") as fh:
        for i, row in enumerate(csv.DictReader(fh)):
            if stride > 1 and i % stride != 0:
                continue
            for c in cols:
                try:
                    fv = float(row.get(c, ""))
                except (TypeError, ValueError):
                    # Empty or non-numeric text, or None for a short row.
                    continue
                if math.isfinite(fv):
                    values[c].append(fv)
            row_count += 1
            if max_rows > 0 and row_count >= max_rows:
                break
    return values, row_count


def ks_statistic(a, b):
    """Return the two-sample Kolmogorov-Smirnov statistic ``sup |F_a - F_b|``.

    Tied values are consumed together in both samples before the empirical
    CDFs are compared, so identical samples give 0.0 even with repeats.
    Returns 1.0 (maximal distance) when either sample is empty.
    """
    if not a or not b:
        return 1.0
    a = sorted(a)
    b = sorted(b)
    na, nb = len(a), len(b)
    i = j = 0
    d = 0.0
    while i < na and j < nb:
        x = min(a[i], b[j])
        # Step both ECDFs past every copy of x; comparing mid-tie overstates D.
        while i < na and a[i] <= x:
            i += 1
        while j < nb and b[j] <= x:
            j += 1
        d = max(d, abs(i / na - j / nb))
    # Once one sample is exhausted its ECDF is 1 and the gap can only shrink,
    # so the supremum has already been recorded.
    return d


def ecdf_points(vals):
    """Return the empirical CDF of *vals* as ``(xs, ys)``.

    ``xs`` are the distinct sorted values and ``ys[k]`` is the fraction of
    samples ``<= xs[k]``. Empty input gives ``([], [])``.
    """
    vals = sorted(vals)
    n = len(vals)
    xs = []
    ys = []
    for i, v in enumerate(vals, 1):
        if xs and v == xs[-1]:
            ys[-1] = i / n
        else:
            xs.append(v)
            ys.append(i / n)
    return xs, ys


def render_cdf_svg(out_path: Path, feature, real_vals, gen_vals, bounds=None):
    """Write an SVG comparing the ECDFs of real and generated values.

    Real data is drawn in blue and generated data in red, both as step
    functions over the shared value range. When *bounds* ``(lo, hi)`` is
    given, dashed vertical markers are drawn at those x positions. Nothing
    is written if either sample is empty.
    """
    if not real_vals or not gen_vals:
        return
    width, height = 900, 420
    pad = 50
    panel_w = width - pad * 2
    panel_h = height - pad * 2
    min_v = min(min(real_vals), min(gen_vals))
    max_v = max(max(real_vals), max(gen_vals))
    if max_v == min_v:
        max_v += 1.0  # constant data: avoid a zero-width x range
    span = max_v - min_v
    x_left, x_right = pad, pad + panel_w
    y_bottom, y_top = pad + panel_h, pad  # CDF 0 at the bottom, 1 at the top
    font = 'font-family="sans-serif"'

    def sx(v):
        return pad + int((v - min_v) * panel_w / span)

    def sy(v):
        return pad + panel_h - int(v * panel_h)

    def step_path(vals, color):
        # Right-continuous step function: rise vertically at each distinct x.
        xs, ys = ecdf_points(vals)
        parts = [f"M{sx(min_v)},{sy(0.0)}"]
        prev_y = 0.0
        for x, y in zip(xs, ys):
            parts.append(f"L{sx(x)},{sy(prev_y)}")
            parts.append(f"L{sx(x)},{sy(y)}")
            prev_y = y
        parts.append(f"L{sx(max_v)},{sy(1.0)}")
        d = " ".join(parts)
        return f'<path d="{d}" fill="none" stroke="{color}" stroke-width="1.5"/>'

    # Feature names are arbitrary CSV headers; escape them for XML text.
    label = html.escape(str(feature))
    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" '
        f'viewBox="0 0 {width} {height}">',
        f'<rect x="0" y="0" width="{width}" height="{height}" fill="white"/>',
        f'<text x="{pad}" y="{pad - 20}" {font} font-size="16">CDF 비교: {label}</text>',
        f'<line x1="{x_left}" y1="{y_bottom}" x2="{x_right}" y2="{y_bottom}" stroke="#333"/>',
        f'<line x1="{x_left}" y1="{y_bottom}" x2="{x_left}" y2="{y_top}" stroke="#333"/>',
        f'<text x="{x_left}" y="{y_bottom + 18}" {font} font-size="11">{min_v:.6g}</text>',
        f'<text x="{x_right}" y="{y_bottom + 18}" {font} font-size="11" '
        f'text-anchor="end">{max_v:.6g}</text>',
        f'<text x="{x_left - 6}" y="{y_bottom}" {font} font-size="11" text-anchor="end">0</text>',
        f'<text x="{x_left - 6}" y="{y_top + 4}" {font} font-size="11" text-anchor="end">1</text>',
        step_path(real_vals, "#1f77b4"),  # real
        step_path(gen_vals, "#d62728"),  # generated
        f'<text x="{x_right - 130}" y="{pad - 20}" {font} font-size="13" fill="#1f77b4">real</text>',
        f'<text x="{x_right - 80}" y="{pad - 20}" {font} font-size="13" fill="#d62728">generated</text>',
    ]
    if bounds is not None:
        lo, hi = bounds
        for v in (lo, hi):
            svg.append(
                f'<line x1="{sx(v)}" y1="{y_bottom}" x2="{sx(v)}" y2="{y_top}" '
                f'stroke="#888" stroke-dasharray="4,3"/>'
            )
    svg.append("</svg>")
    out_path.write_text("\n".join(svg), encoding="utf-8")


def _is_gzip(path: Path) -> bool:
    """Return True if *path* begins with the gzip magic number."""
    with open(path, "rb") as fh:
        return fh.read(2) == b"\x1f\x8b"


def _boundary_pileup(gen, lo, hi):
    """Return the fractions of *gen* within tolerance of *lo* and of *hi*.

    The tolerance is ``1e-4 * (hi - lo)``, or ``1e-6`` absolute when the range
    is degenerate. Empty *gen* gives ``(0.0, 0.0)``.
    """
    if not gen:
        return 0.0, 0.0
    tol = (hi - lo) * 1e-4 if hi > lo else 1e-6
    n = len(gen)
    frac_lo = sum(1 for v in gen if abs(v - lo) <= tol) / n
    frac_hi = sum(1 for v in gen if abs(v - hi) <= tol) / n
    return frac_lo, frac_hi


def main():
    """Run the diagnostics end to end and write the CSV, SVG and JSON outputs."""
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    cont_cols = load_split(base_dir)
    ref_glob = resolve_reference_glob(base_dir, args.reference)
    ref_files = sorted(glob(ref_glob))
    if not ref_files:
        raise SystemExit(f"no reference files matched: {ref_glob}")

    # Compression is detected from file content, so plain or gzipped CSVs work
    # for either input.
    gen_path = Path(args.generated)
    gen_vals, gen_rows = read_csv_values(
        gen_path, cont_cols, args.max_rows, args.stride, gz=_is_gzip(gen_path)
    )

    # Reference values, aggregated across all matched files (max_rows is per file).
    real_vals = {c: [] for c in cont_cols}
    total_rows = 0
    for f in ref_files:
        ref_path = Path(f)
        vals, rows = read_csv_values(
            ref_path, cont_cols, args.max_rows, args.stride, gz=_is_gzip(ref_path)
        )
        total_rows += rows
        for c in cont_cols:
            real_vals[c].extend(vals[c])

    # KS per feature, plus how much generated mass piles up on the real min/max.
    ks_rows = []
    for c in cont_cols:
        real, gen = real_vals[c], gen_vals[c]
        ks = ks_statistic(real, gen)
        lo, hi = (min(real), max(real)) if real else (0.0, 0.0)
        frac_lo, frac_hi = _boundary_pileup(gen, lo, hi)
        ks_rows.append((c, ks, frac_lo, frac_hi, len(real), len(gen), lo, hi))
    ks_rows.sort(key=lambda r: r[1], reverse=True)

    out_csv = out_dir / "ks_per_feature.csv"
    with out_csv.open("w", newline="", encoding="utf-8") as fh:
        w = csv.writer(fh)
        w.writerow(["feature", "ks", "gen_frac_at_min", "gen_frac_at_max",
                    "real_n", "gen_n", "real_min", "real_max"])
        w.writerows(ks_rows)

    # Top-k CDF plots (worst KS first).
    top = ks_rows[: args.top_k]
    for c, _ks, _flo, _fhi, _rn, _gn, lo, hi in top:
        render_cdf_svg(out_dir / f"cdf_{c}.svg", c, real_vals[c], gen_vals[c], bounds=(lo, hi))

    summary = {
        "generated_rows": gen_rows,
        "reference_rows_per_file": args.max_rows if args.max_rows > 0 else "full",
        "reference_files": len(ref_files),
        "reference_rows_total": total_rows,
        "stride": args.stride,
        "top_k_features": [r[0] for r in top],
    }
    (out_dir / "ks_summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f"wrote {out_csv}")
    print(f"wrote CDF svgs for top {args.top_k} features")


if __name__ == "__main__":
    main()