#!/usr/bin/env python3
"""Plot benchmark metrics and diagnostic figures from benchmark CSV/JSON outputs.

Figures: multi-panel overview, seed-robustness summary, KS outlier attribution,
feature series lines, empirical-CDF grid, and discrete-marginal grids/points.

NOTE(review): this module was recovered from a whitespace-mangled source in
which many SVG markup string literals appear to have been stripped down to "".
Those literals are preserved byte-for-byte below and flagged inline; the
.format(...) keyword arguments hint at the original markup — restore the tag
text from version control if possible.
"""

import argparse
import csv
import json
import math
from pathlib import Path


def parse_args():
    """Build and parse the command-line interface for the plotting tool."""
    parser = argparse.ArgumentParser(description="Plot benchmark metrics from benchmark_history.csv")
    base_dir = Path(__file__).resolve().parent
    parser.add_argument(
        "--figure",
        choices=["panel", "summary", "ranked_ks", "lines", "cdf_grid", "disc_grid", "disc_points"],
        default="panel",
        help="Figure type: panel (paper-style multi-panel), summary (seed robustness only), ranked_ks (outlier attribution), lines (feature series), or cdf_grid (distributions).",
    )
    parser.add_argument(
        "--history",
        default=str(base_dir / "results" / "benchmark_history.csv"),
        help="Path to benchmark_history.csv",
    )
    parser.add_argument(
        "--generated",
        default=str(base_dir / "results" / "generated.csv"),
        help="Path to generated.csv (used for per-feature profile in panel A).",
    )
    parser.add_argument(
        "--cont-stats",
        default=str(base_dir / "results" / "cont_stats.json"),
        help="Path to cont_stats.json (used for per-feature profile in panel A).",
    )
    parser.add_argument(
        "--profile-order",
        choices=["ks_desc", "name"],
        default="ks_desc",
        help="Feature ordering for panel A profile: ks_desc or name.",
    )
    parser.add_argument(
        "--profile-max-features",
        type=int,
        default=64,
        help="Max features in panel A profile (0 means all).",
    )
    parser.add_argument(
        "--ks-per-feature",
        default=str(base_dir / "results" / "ks_per_feature.csv"),
        help="Path to ks_per_feature.csv",
    )
    parser.add_argument(
        "--data-shift",
        default=str(base_dir / "results" / "data_shift_stats.csv"),
        help="Path to data_shift_stats.csv",
    )
    parser.add_argument(
        "--metrics-history",
        default=str(base_dir / "results" / "metrics_history.csv"),
        help="Path to metrics_history.csv",
    )
    parser.add_argument(
        "--filtered-metrics",
        default=str(base_dir / "results" / "filtered_metrics.json"),
        help="Path to filtered_metrics.json (optional).",
    )
    parser.add_argument(
        "--ranked-ks",
        default=str(base_dir / "results" / "ranked_ks.csv"),
        help="Path to ranked_ks.csv (used for --figure ranked_ks).",
    )
    parser.add_argument(
        "--ranked-ks-top-n",
        type=int,
        default=30,
        help="Number of top KS features to show in ranked_ks figure.",
    )
    parser.add_argument(
        "--out",
        default="",
        help="Output SVG path (default depends on --figure).",
    )
    parser.add_argument(
        "--engine",
        choices=["auto", "matplotlib", "svg"],
        default="auto",
        help="Plotting engine: auto prefers matplotlib if available; otherwise uses pure-SVG.",
    )
    parser.add_argument(
        "--lines-features",
        default="",
        help="Comma-separated feature names for --figure lines (default: top-4 from ranked_ks.csv or fallback set).",
    )
    parser.add_argument(
        "--lines-top-k",
        type=int,
        default=8,
        help="When features not specified, take top-K features from ranked_ks.csv.",
    )
    parser.add_argument(
        "--lines-max-rows",
        type=int,
        default=1000,
        help="Max rows to read from generated.csv for --figure lines.",
    )
    parser.add_argument(
        "--lines-normalize",
        choices=["none", "real_range"],
        default="none",
        help="Normalization for --figure lines: none or real_range (use cont_stats min/max).",
    )
    parser.add_argument(
        "--reference",
        default=str(base_dir / "config.json"),
        help="Reference source for real data: config.json with data_glob or direct CSV/GZ path.",
    )
    parser.add_argument(
        "--lines-ref-index",
        type=int,
        default=0,
        help="Index of matched reference file to plot (0-based) when using a glob.",
    )
    parser.add_argument(
        "--cdf-features",
        default="",
        help="Comma-separated features for cdf_grid; empty = all continuous features from cont_stats.",
    )
    parser.add_argument(
        "--cdf-max-features",
        type=int,
        default=64,
        help="Max features to include in cdf_grid.",
    )
    parser.add_argument(
        "--cdf-bins",
        type=int,
        default=80,
        help="Number of bins for empirical CDF.",
    )
    parser.add_argument(
        "--disc-features",
        default="",
        help="Comma-separated features for disc_grid; empty = use discrete list from feature_split.json.",
    )
    parser.add_argument(
        "--disc-max-features",
        type=int,
        default=64,
        help="Max features to include in disc_grid.",
    )
    parser.add_argument(
        "--feature-split",
        default=str(base_dir / "feature_split.json"),
        help="Path to feature_split.json with 'continuous' and 'discrete' lists.",
    )
    return parser.parse_args()


def mean_std(vals):
    """Return (mean, sample std) of a non-empty list; std is 0.0 for a single value."""
    m = sum(vals) / len(vals)
    if len(vals) == 1:
        return m, 0.0
    # Sample (n-1) variance.
    v = sum((x - m) * (x - m) for x in vals) / (len(vals) - 1)
    return m, math.sqrt(v)


def svg_escape(s):
    """Escape a value for safe embedding in SVG/XML text and attribute content.

    Bug fix: the mangled source contained identity replacements (e.g.
    ``.replace("&", "&")``) and a broken ``\"\"\"`` token; restored the
    standard XML entity escapes (ampersand first so it is not double-escaped).
    """
    return (
        str(s)
        .replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
        .replace("'", "&apos;")
    )


def clamp(v, lo, hi):
    """Clamp v into the inclusive range [lo, hi]."""
    if v < lo:
        return lo
    if v > hi:
        return hi
    return v


def lerp(a, b, t):
    """Linear interpolation between a and b at parameter t."""
    return a + (b - a) * t


def hex_to_rgb(h):
    """Convert '#rrggbb' (leading '#' optional) to an (r, g, b) int tuple."""
    h = h.lstrip("#")
    return int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)


def rgb_to_hex(r, g, b):
    """Convert r/g/b components (clamped to 0..255) to a '#rrggbb' string."""
    return "#{:02x}{:02x}{:02x}".format(int(clamp(r, 0, 255)), int(clamp(g, 0, 255)), int(clamp(b, 0, 255)))


def diverging_color(v, vmin=-2.0, vmax=2.0, cold="#2563eb", hot="#ef4444", mid="#ffffff"):
    """Map v in [vmin, vmax] onto a diverging color scale (cold .. mid .. hot)."""
    v = clamp(v, vmin, vmax)
    if v >= 0:
        t = 0.0 if vmax == 0 else v / vmax
        r0, g0, b0 = hex_to_rgb(mid)
        r1, g1, b1 = hex_to_rgb(hot)
        return rgb_to_hex(lerp(r0, r1, t), lerp(g0, g1, t), lerp(b0, b1, t))
    t = 0.0 if vmin == 0 else (-v) / (-vmin)
    r0, g0, b0 = hex_to_rgb(mid)
    r1, g1, b1 = hex_to_rgb(cold)
    return rgb_to_hex(lerp(r0, r1, t), lerp(g0, g1, t), lerp(b0, b1, t))


def plot_matplotlib(rows, seeds, metrics, out_path):
    """Render the seed-robustness strip plot (one row per metric) and save as SVG.

    rows: list of dicts keyed by metric key; seeds: labels per row;
    metrics: list of (key, title) pairs; out_path: pathlib.Path for output.
    """
    import matplotlib.pyplot as plt

    try:
        plt.style.use("seaborn-v0_8-whitegrid")
    except Exception:
        pass
    fig, axes = plt.subplots(nrows=len(metrics), ncols=1, figsize=(8.6, 4.6), sharex=False)
    if len(metrics) == 1:
        axes = [axes]
    point_color = "#3b82f6"
    band_color = "#ef4444"
    grid_color = "#e5e7eb"
    axis_color = "#111827"
    for ax, (key, title) in zip(axes, metrics, strict=True):
        vals = [r[key] for r in rows]
        m, s = mean_std(vals)
        vmin = min(vals + [m - s])
        vmax = max(vals + [m + s])
        if vmax == vmin:
            vmax = vmin + 1.0
        # Pad the x-range by 20% on each side for breathing room.
        vr = vmax - vmin
        vmin -= 0.20 * vr
        vmax += 0.20 * vr
        y0 = 0.0
        jitter = [-0.08, 0.0, 0.08]
        ys = [(y0 + jitter[i % len(jitter)]) for i in range(len(vals))]
        ax.axvspan(m - s, m + s, color=band_color, alpha=0.10, linewidth=0)
        ax.axvline(m, color=band_color, linewidth=2.2)
        ax.scatter(vals, ys, s=46, color=point_color, edgecolors="white", linewidths=1.0, zorder=3)
        for x, y, seed in zip(vals, ys, seeds, strict=True):
            ax.annotate(
                str(seed),
                (x, y),
                textcoords="offset points",
                xytext=(0, 10),
                ha="center",
                va="bottom",
                fontsize=8,
                color=axis_color,
            )
        ax.set_title(title, loc="left", fontsize=11, color=axis_color, pad=8)
        ax.set_yticks([])
        ax.set_ylim(-0.35, 0.35)
        ax.set_xlim(vmin, vmax)
        ax.grid(True, axis="x", color=grid_color)
        ax.grid(False, axis="y")
        ax.text(
            0.99,
            0.80,
            "mean={m:.4f} ± {s:.4f}".format(m=m, s=s),
            transform=ax.transAxes,
            ha="right",
            va="center",
            fontsize=9,
            color="#374151",
        )
    fig.suptitle("Benchmark Metrics (3 seeds) · lower is better", fontsize=12, color=axis_color, y=0.98)
    fig.tight_layout(rect=(0, 0, 1, 0.95))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, format="svg")
    plt.close(fig)


def discrete_points_matplotlib(generated_csv_path, reference_arg, features, out_path, max_rows=5000):
    """Dot plot of per-category probabilities for discrete features, generated vs real."""
    import matplotlib.pyplot as plt
    import numpy as np
    import gzip

    try:
        plt.style.use("seaborn-v0_8-whitegrid")
    except Exception:
        pass

    def resolve_reference_glob(ref_arg: str) -> str:
        # A .json reference is a config whose data_glob/data_path points at the data.
        ref_path = Path(ref_arg)
        if ref_path.suffix == ".json":
            cfg = json.loads(ref_path.read_text(encoding="utf-8"))
            data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
            if not data_glob:
                raise SystemExit("reference config has no data_glob/data_path")
            combined = ref_path.parent / data_glob
            if "*" in str(combined) or "?" in str(combined):
                return str(combined)
            return str(combined.resolve())
        return str(ref_path)

    def read_rows(path, limit):
        rows = []
        opener = gzip.open if str(path).endswith(".gz") else open
        with opener(path, "rt", newline="") as fh:
            reader = csv.DictReader(fh)
            for i, r in enumerate(reader):
                rows.append(r)
                if limit > 0 and i + 1 >= limit:
                    break
        return rows

    def cats(rows, feat):
        # Collect non-empty, non-NaN category strings for one feature.
        vs = []
        for r in rows:
            v = r.get(feat)
            if v is None:
                continue
            s = str(v).strip()
            if s == "" or s.lower() == "nan":
                continue
            vs.append(s)
        return vs

    ref_glob = resolve_reference_glob(reference_arg)
    ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
    gen_rows = read_rows(generated_csv_path, max_rows)
    # Fall back to the generated file itself when the glob matched nothing.
    ref_rows = read_rows(ref_paths[0] if ref_paths else generated_csv_path, max_rows)

    points = []
    group_spans = []
    x = 0
    for feat in features:
        gvs = cats(gen_rows, feat)
        rvs = cats(ref_rows, feat)
        cats_all = sorted(set(gvs) | set(rvs))
        start_x = x
        for c in cats_all:
            g_count = sum(1 for v in gvs if v == c)
            r_count = sum(1 for v in rvs if v == c)
            g_total = len(gvs) or 1
            r_total = len(rvs) or 1
            g_p = g_count / g_total
            r_p = r_count / r_total
            points.append({"x": x, "feat": feat, "cat": c, "g": g_p, "r": r_p})
            x += 1
        end_x = x - 1
        if end_x >= start_x:
            group_spans.append({"feat": feat, "x0": start_x, "x1": end_x})
        x += 1  # gap between feature groups
    if not points:
        fig, ax = plt.subplots(figsize=(9, 3))
        ax.axis("off")
        ax.text(0.5, 0.5, "no discrete data", ha="center", va="center")
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, format="svg")
        plt.close(fig)
        return

    width = max(9.0, min(18.0, 0.08 * len(points)))
    fig, ax = plt.subplots(figsize=(width, 6.2))
    xs = np.array([p["x"] for p in points], dtype=float)
    jitter = 0.18
    ax.scatter(xs - jitter, [p["g"] for p in points], s=26, color="#2563eb", alpha=0.85, edgecolors="white", linewidths=0.6, label="generated")
    ax.scatter(xs + jitter, [p["r"] for p in points], s=26, color="#ef4444", alpha=0.75, edgecolors="white", linewidths=0.6, label="real")
    for span in group_spans:
        xc = (span["x0"] + span["x1"]) / 2.0
        ax.axvline(span["x1"] + 0.5, color="#e5e7eb", lw=1.0)
        ax.text(xc, -0.06, span["feat"], ha="center", va="top", rotation=25, fontsize=8, color="#374151")
    xg_all = xs - jitter
    yg_all = [p["g"] for p in points]
    xr_all = xs + jitter
    yr_all = [p["r"] for p in points]
    ax.plot(xg_all, yg_all, color="#2563eb", linewidth=1.2, alpha=0.85)
    ax.plot(xr_all, yr_all, color="#ef4444", linewidth=1.2, alpha=0.75)
    ax.fill_between(xg_all, yg_all, 0.0, color="#2563eb", alpha=0.14, step=None)
    ax.fill_between(xr_all, yr_all, 0.0, color="#ef4444", alpha=0.14, step=None)
    ax.set_ylim(-0.08, 1.02)
    ax.set_xlim(-0.5, max(xs) + 0.5)
    ax.set_ylabel("probability", fontsize=10)
    ax.set_xticks([])
    ax.grid(True, axis="y", color="#e5e7eb")
    ax.legend(loc="upper right", fontsize=9)
    fig.suptitle("Discrete Marginals (dot plot): generated vs real", fontsize=12, color="#111827", y=0.98)
    fig.tight_layout(rect=(0, 0, 1, 0.96))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, format="svg")
    plt.close(fig)


def plot_svg(rows, seeds, metrics, out_path):
    """Pure-SVG fallback for the seed-robustness figure (no matplotlib needed).

    NOTE(review): many SVG tag literals below were stripped to "" in the
    recovered source; the .format kwargs indicate what each append once
    emitted (svg/rect/line/text/circle elements). Restore from history.
    """
    W, H = 980, 440
    pad_l, pad_r, pad_t, pad_b = 200, 30, 74, 36
    row_gap = 52
    row_h = (H - pad_t - pad_b - row_gap * (len(metrics) - 1)) / len(metrics)
    bg = "#ffffff"
    axis = "#2b2b2b"
    grid = "#e9e9e9"
    band = "#d62728"
    band_fill = "#d62728"
    point = "#1f77b4"
    text = "#111111"
    subtle = "#666666"
    parts = []
    parts.append("".format(w=W, h=H))  # NOTE(review): stripped <svg ...> open tag
    parts.append("".format(w=W, h=H, bg=bg))  # NOTE(review): stripped background rect
    parts.append("Benchmark Metrics (3 seeds)".format(x=W / 2, c=text))
    parts.append("line: mean · band: ±1 std · dots: runs · lower is better".format(x=W / 2, c=subtle))
    parts.append("seeds: {s}".format(x=W / 2, c=subtle, s=svg_escape(", ".join(seeds))))
    plot_x0 = pad_l
    plot_x1 = W - pad_r
    for mi, (key, title) in enumerate(metrics):
        y0 = pad_t + mi * (row_h + row_gap)
        y1 = y0 + row_h
        yc = (y0 + y1) / 2
        vals = [r[key] for r in rows]
        m, s = mean_std(vals)
        vmin = min(vals + [m - s])
        vmax = max(vals + [m + s])
        if vmax == vmin:
            vmax = vmin + 1.0
        vr = vmax - vmin
        vmin -= 0.20 * vr
        vmax += 0.20 * vr

        def X(v):
            # Map a data value onto the horizontal plot span.
            return plot_x0 + (v - vmin) * (plot_x1 - plot_x0) / (vmax - vmin)

        if mi % 2 == 1:
            # Zebra stripe behind every other metric row.
            parts.append("".format(x=0, y=y0 - 8, w=W, h=row_h + 16))
        parts.append("{t}".format(x=pad_l - 12, y=yc + 4, c=text, t=svg_escape(title)))
        for k in range(6):
            xx = plot_x0 + k * (plot_x1 - plot_x0) / 5
            parts.append("".format(x=xx, y0=y0 + 4, y1=y1 - 4, c=grid))
            val = vmin + k * (vmax - vmin) / 5
            parts.append("{v:.4f}".format(x=xx, y=y1 + 28, c=subtle, v=val))
        parts.append("".format(x0=plot_x0, x1=plot_x1, y=yc, c=axis))
        x_lo = X(m - s)
        x_hi = X(m + s)
        parts.append("".format(x=min(x_lo, x_hi), y=yc - 14, w=abs(x_hi - x_lo), h=28, c=band_fill))
        xm = X(m)
        parts.append("".format(x=xm, y0=yc - 16, y1=yc + 16, c=band))
        parts.append("mean={m:.4f}±{s:.4f}".format(x=xm + 6, y=yc - 18, c=subtle, m=m, s=s))
        for i, r in enumerate(rows):
            # Deterministic pseudo-random vertical jitter so dots do not overlap.
            jitter = ((i * 37) % 11 - 5) * 0.9
            xx = X(r[key])
            yy = yc + jitter
            parts.append("".format(x=xx, y=yy, c=point))
    parts.append("")  # NOTE(review): stripped </svg> close tag
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("\n".join(parts), encoding="utf-8")


def ranked_ks_svg(ranked_rows, out_path, top_n=30):
    """Two-panel SVG: top-K KS features (bars) and avg-KS-after-removal curve.

    NOTE(review): several SVG tag literals below were stripped to "" in the
    recovered source; kwargs indicate the intended markup.
    """
    parsed = []
    for r in ranked_rows or []:
        feat = (r.get("feature") or "").strip()
        if not feat:
            continue
        ks = parse_float(r.get("ks"))
        if ks is None:
            continue
        rank = parse_float(r.get("rank"))
        contrib = parse_float(r.get("contribution_to_avg"))
        avg_if = parse_float(r.get("avg_ks_if_remove_top_n"))
        parsed.append(
            {
                "rank": int(rank) if isinstance(rank, (int, float)) else None,
                "feature": feat,
                "ks": float(ks),
                "contrib": float(contrib) if isinstance(contrib, (int, float)) else None,
                "avg_if": float(avg_if) if isinstance(avg_if, (int, float)) else None,
            }
        )
    if not parsed:
        raise SystemExit("no valid rows in ranked_ks.csv")
    parsed_by_ks = sorted(parsed, key=lambda x: x["ks"], reverse=True)
    top_n = max(1, int(top_n))
    top = parsed_by_ks[: min(top_n, len(parsed_by_ks))]
    parsed_by_rank = sorted(
        [p for p in parsed if isinstance(p.get("rank"), int) and p.get("avg_if") is not None],
        key=lambda x: x["rank"],
    )
    # Baseline avg KS before removing anything: reconstruct from the first
    # removal row when possible, else fall back to the plain mean.
    baseline = None
    if parsed_by_rank:
        first = parsed_by_rank[0]
        if first.get("avg_if") is not None and first.get("contrib") is not None:
            baseline = first["avg_if"] + first["contrib"]
        elif first.get("avg_if") is not None:
            baseline = first["avg_if"]
    if baseline is None:
        baseline = sum(p["ks"] for p in parsed_by_ks) / len(parsed_by_ks)
    xs = [0]
    ys = [baseline]
    for p in parsed_by_rank:
        xs.append(p["rank"])
        ys.append(p["avg_if"])

    W, H = 980, 560
    bg = "#ffffff"
    ink = "#0f172a"
    subtle = "#64748b"
    grid = "#e2e8f0"
    border = "#cbd5e1"
    blue = "#2563eb"
    bar = "#0ea5e9"
    ff = "system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif"

    def text(x, y, s, size=12, anchor="start", color=ink, weight="normal"):
        # NOTE(review): stripped <text ...>{t}</text> markup.
        return "{t}".format(x=x, y=y, a=anchor, ff=ff, fs=size, c=color, w=weight, t=svg_escape(s))

    def line(x1, y1, x2, y2, color=grid, width=1.0, dash=None, opacity=1.0, cap="round"):
        extra = ""
        if dash:
            extra += " stroke-dasharray='{d}'".format(d=dash)
        if opacity != 1.0:
            extra += " stroke-opacity='{o}'".format(o=opacity)
        # NOTE(review): stripped <line .../> markup.
        return "".format(x1=x1, y1=y1, x2=x2, y2=y2, c=color, w=width, cap=cap, extra=extra)

    pad = 26
    title_h = 62
    gap = 18
    card_w = (W - 2 * pad - gap) / 2
    card_h = H - pad - title_h
    xA = pad
    xB = pad + card_w + gap
    y0 = title_h
    parts = []
    parts.append("".format(w=W, h=H))  # NOTE(review): stripped <svg> open tag
    parts.append("".format(w=W, h=H, bg=bg))  # NOTE(review): stripped background rect
    parts.append(text(W / 2, 28, "KS Outlier Attribution", size=16, anchor="middle", color=ink, weight="bold"))
    parts.append(
        text(
            W / 2,
            48,
            "Left: top-K features by KS · Right: avg KS after removing top-n outliers",
            size=11,
            anchor="middle",
            color=subtle,
            weight="normal",
        )
    )
    parts.append("".format(x=xA, y=y0, w=card_w, h=card_h, b=border))  # NOTE(review): stripped card rect
    parts.append("".format(x=xB, y=y0, w=card_w, h=card_h, b=border))  # NOTE(review): stripped card rect
    parts.append(text(xA + 18, y0 + 28, "A Top-{k} KS features".format(k=len(top)), size=12, color=ink, weight="bold"))
    parts.append(text(xB + 18, y0 + 28, "B Removing worst features", size=12, color=ink, weight="bold"))

    # Panel A: horizontal bars, one per feature.
    ax_y0 = y0 + 52
    ax_h = card_h - 76
    label_w = 165
    ax_x0 = xA + 18 + label_w
    ax_x1 = xA + card_w - 18
    row_h = ax_h / max(1, len(top))
    for t in range(6):
        xx = ax_x0 + (ax_x1 - ax_x0) * (t / 5)
        parts.append(line(xx, ax_y0, xx, ax_y0 + ax_h, color=grid, width=1.0))
        parts.append(text(xx, ax_y0 + ax_h + 20, "{:.1f}".format(t / 5), size=9, anchor="middle", color=subtle))
    parts.append(line(ax_x0, ax_y0 + ax_h, ax_x1, ax_y0 + ax_h, color=border, width=1.2, cap="butt"))
    parts.append(text(ax_x0, ax_y0 + ax_h + 38, "KS", size=10, anchor="start", color=subtle, weight="bold"))
    for i, p in enumerate(top):
        cy = ax_y0 + i * row_h + row_h / 2
        parts.append(text(ax_x0 - 10, cy + 4, p["feature"], size=9, anchor="end", color=ink))
        w = (ax_x1 - ax_x0) * clamp(p["ks"], 0.0, 1.0)
        parts.append("".format(x=ax_x0, y=cy - row_h * 0.34, w=w, h=row_h * 0.68, c=bar))  # NOTE(review): stripped bar rect
        parts.append(text(ax_x1, cy + 4, "{:.3f}".format(p["ks"]), size=9, anchor="end", color=subtle))
        if p.get("contrib") is not None:
            parts.append(text(ax_x1 + 10, cy + 4, "{:.2f}%".format(100.0 * p["contrib"]), size=9, anchor="start", color=subtle))

    # Panel B: avg KS vs number of removed features.
    px0 = xB + 54
    py0 = y0 + 70
    pw = card_w - 78
    ph = card_h - 108
    parts.append(text(px0, py0 - 20, "avg KS (lower is better)", size=10, anchor="start", color=subtle, weight="bold"))
    xmax = max(xs) if xs else 1
    ymin = min(ys) if ys else 0.0
    ymax = max(ys) if ys else 1.0
    if ymax == ymin:
        ymax = ymin + 1.0
    yr = ymax - ymin
    ymin -= 0.12 * yr
    ymax += 0.12 * yr

    def X(v):
        return px0 + (v / max(1e-9, xmax)) * pw

    def Y(v):
        return py0 + ph - ((v - ymin) / (ymax - ymin)) * ph

    for k in range(6):
        yy = py0 + (ph * k / 5)
        parts.append(line(px0, yy, px0 + pw, yy, color=grid, width=1.0))
        val = ymax - (ymax - ymin) * (k / 5)
        parts.append(text(px0 - 8, yy + 4, "{:.3f}".format(val), size=9, anchor="end", color=subtle))
    for k in range(6):
        xx = px0 + (pw * k / 5)
        parts.append(line(xx, py0, xx, py0 + ph, color=grid, width=1.0))
        val = int(round(xmax * (k / 5)))
        parts.append(text(xx, py0 + ph + 22, str(val), size=9, anchor="middle", color=subtle))
    parts.append(text(px0 + pw / 2, py0 + ph + 42, "remove top-n features", size=10, anchor="middle", color=subtle, weight="bold"))
    d = []
    for x, y in zip(xs, ys, strict=True):
        d.append(("M" if not d else "L") + " {x:.1f} {y:.1f}".format(x=X(x), y=Y(y)))
    parts.append("".format(d=" ".join(d), c=blue))  # NOTE(review): stripped <path .../> markup
    for x, y in zip(xs, ys, strict=True):
        parts.append("".format(x=X(x), y=Y(y), c=blue))  # NOTE(review): stripped <circle .../> markup
    parts.append("")  # NOTE(review): stripped </svg> close tag
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("\n".join(parts), encoding="utf-8")


def lines_matplotlib(generated_csv_path, cont_stats, features, out_path, max_rows=1000, normalize="none", reference_arg="", ref_index=0):
    """Plot per-feature time series for generated vs real data, one subplot per feature."""
    import matplotlib.pyplot as plt
    import gzip

    try:
        plt.style.use("seaborn-v0_8-whitegrid")
    except Exception:
        pass
    rows_gen = []
    with Path(generated_csv_path).open("r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        for i, r in enumerate(reader):
            rows_gen.append(r)
            if len(rows_gen) >= max_rows:
                break
    # Use the 'time' column when present, else fall back to the row index.
    xs_gen = [parse_float(r.get("time")) or (i if (r.get("time") is None) else 0.0) for i, r in enumerate(rows_gen)]
    mins = {}
    maxs = {}
    if isinstance(cont_stats, dict):
        mins = cont_stats.get("min", {}) or {}
        maxs = cont_stats.get("max", {}) or {}

    def norm_val(feat, v):
        # Optionally rescale into [0, 1] using the real data's min/max.
        if normalize != "real_range":
            return v
        mn = parse_float(mins.get(feat))
        mx = parse_float(maxs.get(feat))
        if mn is None or mx is None:
            return v
        denom = mx - mn
        if denom == 0:
            return v
        return (v - mn) / denom

    def resolve_reference_glob(ref_arg: str) -> str:
        ref_path = Path(ref_arg)
        if ref_path.suffix == ".json":
            cfg = json.loads(ref_path.read_text(encoding="utf-8"))
            data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
            if not data_glob:
                raise SystemExit("reference config has no data_glob/data_path")
            combined = ref_path.parent / data_glob
            if "*" in str(combined) or "?" in str(combined):
                return str(combined)
            return str(combined.resolve())
        return str(ref_path)

    def read_series(path: Path, cols, max_rows: int):
        # NOTE(review): defined but currently unused in this function.
        vals = {c: [] for c in cols}
        opener = gzip.open if str(path).endswith(".gz") else open
        with opener(path, "rt", newline="") as fh:
            reader = csv.DictReader(fh)
            for i, row in enumerate(reader):
                for c in cols:
                    try:
                        vals[c].append(float(row[c]))
                    except Exception:
                        pass
                if max_rows > 0 and i + 1 >= max_rows:
                    break
        return vals

    ref_glob = resolve_reference_glob(reference_arg or str(Path(__file__).resolve().parent / "config.json"))
    ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
    ref_rows = []
    if ref_paths:
        idx = max(0, min(ref_index, len(ref_paths) - 1))
        first = ref_paths[idx]
        opener = gzip.open if str(first).endswith(".gz") else open
        with opener(first, "rt", newline="") as fh:
            reader = csv.DictReader(fh)
            for i, r in enumerate(reader):
                ref_rows.append(r)
                if len(ref_rows) >= max_rows:
                    break
    xs_ref = [i for i, _ in enumerate(ref_rows)]
    fig, axes = plt.subplots(nrows=len(features), ncols=1, figsize=(9.2, 6.6), sharex=True)
    if len(features) == 1:
        axes = [axes]
    for ax, feat in zip(axes, features, strict=True):
        ys_gen = [norm_val(feat, parse_float(r.get(feat)) or 0.0) for r in rows_gen]
        ax.plot(xs_gen, ys_gen, color="#2563eb", linewidth=1.6, label="generated")
        if ref_rows:
            ys_ref = [norm_val(feat, parse_float(r.get(feat)) or 0.0) for r in ref_rows]
            ax.plot(xs_ref, ys_ref, color="#ef4444", linewidth=1.2, alpha=0.75, label="real")
        ax.set_ylabel(feat, fontsize=10)
        ax.grid(True, color="#e5e7eb")
        ax.legend(loc="upper right", fontsize=8)
    axes[-1].set_xlabel("time", fontsize=10)
    fig.suptitle("Feature Series: generated vs real", fontsize=12, color="#111827", y=0.98)
    fig.tight_layout(rect=(0, 0, 1, 0.96))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, format="svg")
    plt.close(fig)


def cdf_grid_matplotlib(generated_csv_path, reference_arg, cont_stats, features, out_path, max_rows=5000, bins=80):
    """Grid of empirical CDFs (generated vs real), one subplot per continuous feature."""
    import matplotlib.pyplot as plt
    import gzip
    import numpy as np

    try:
        plt.style.use("seaborn-v0_8-whitegrid")
    except Exception:
        pass
    mins = cont_stats.get("min", {}) if isinstance(cont_stats, dict) else {}
    maxs = cont_stats.get("max", {}) if isinstance(cont_stats, dict) else {}

    def resolve_reference_glob(ref_arg: str) -> str:
        ref_path = Path(ref_arg)
        if ref_path.suffix == ".json":
            cfg = json.loads(ref_path.read_text(encoding="utf-8"))
            data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
            if not data_glob:
                raise SystemExit("reference config has no data_glob/data_path")
            combined = ref_path.parent / data_glob
            if "*" in str(combined) or "?" in str(combined):
                return str(combined)
            return str(combined.resolve())
        return str(ref_path)

    def read_rows(path, limit):
        rows = []
        opener = gzip.open if str(path).endswith(".gz") else open
        with opener(path, "rt", newline="") as fh:
            reader = csv.DictReader(fh)
            for i, r in enumerate(reader):
                rows.append(r)
                if limit > 0 and i + 1 >= limit:
                    break
        return rows

    gen_rows = read_rows(generated_csv_path, max_rows)
    ref_glob = resolve_reference_glob(reference_arg)
    ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
    ref_rows = read_rows(ref_paths[0] if ref_paths else generated_csv_path, max_rows)

    def values(rows, feat):
        vs = []
        for r in rows:
            x = parse_float(r.get(feat))
            if x is not None:
                vs.append(x)
        return vs

    cols = 4
    rows_n = int(math.ceil(len(features) / cols)) if features else 1
    fig, axes = plt.subplots(nrows=rows_n, ncols=cols, figsize=(cols * 3.2, rows_n * 2.6))
    axes = np.array(axes).reshape(rows_n, cols)
    for i, feat in enumerate(features):
        rr = i // cols
        cc = i % cols
        ax = axes[rr][cc]
        gvs = values(gen_rows, feat)
        rvs = values(ref_rows, feat)
        mn = parse_float(mins.get(feat))
        mx = parse_float(maxs.get(feat))
        # Prefer the real-data range from cont_stats; fall back to observed range.
        if mn is None or mx is None or mx <= mn:
            lo = min(gvs + rvs) if (gvs or rvs) else 0.0
            hi = max(gvs + rvs) if (gvs or rvs) else (lo + 1.0)
        else:
            lo = mn
            hi = mx
        edges = np.linspace(lo, hi, bins + 1)

        def ecdf(vs):
            if not vs:
                return edges[1:], np.zeros_like(edges[1:])
            hist, _ = np.histogram(vs, bins=edges)
            cdf = np.cumsum(hist).astype(float)
            cdf /= cdf[-1] if cdf[-1] > 0 else 1.0
            xs = edges[1:]
            return xs, cdf

        xg, yg = ecdf(gvs)
        xr, yr = ecdf(rvs)
        ax.plot(xg, yg, color="#2563eb", linewidth=1.6, label="generated")
        ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
        ax.set_title(feat, fontsize=9, loc="left")
        ax.set_ylim(0, 1)
        ax.grid(True, color="#e5e7eb")
    for j in range(i + 1, rows_n * cols):
        rr = j // cols
        cc = j % cols
        axes[rr][cc].axis("off")
    handles, labels = axes[0][0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper center", ncol=2, fontsize=9)
    fig.suptitle("Empirical CDF: generated vs real", fontsize=12, color="#111827", y=0.98)
    fig.tight_layout(rect=(0, 0, 1, 0.96))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, format="svg")
    plt.close(fig)


def discrete_grid_matplotlib(generated_csv_path, reference_arg, features, out_path, max_rows=5000):
    """Grid of discrete-marginal bar charts (generated vs real), one subplot per feature."""
    import matplotlib.pyplot as plt
    import numpy as np
    import gzip

    try:
        plt.style.use("seaborn-v0_8-whitegrid")
    except Exception:
        pass

    def resolve_reference_glob(ref_arg: str) -> str:
        ref_path = Path(ref_arg)
        if ref_path.suffix == ".json":
            cfg = json.loads(ref_path.read_text(encoding="utf-8"))
            data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
            if not data_glob:
                raise SystemExit("reference config has no data_glob/data_path")
            combined = ref_path.parent / data_glob
            if "*" in str(combined) or "?" in str(combined):
                return str(combined)
            return str(combined.resolve())
        return str(ref_path)

    def read_rows(path, limit):
        rows = []
        opener = gzip.open if str(path).endswith(".gz") else open
        with opener(path, "rt", newline="") as fh:
            reader = csv.DictReader(fh)
            for i, r in enumerate(reader):
                rows.append(r)
                if limit > 0 and i + 1 >= limit:
                    break
        return rows

    ref_glob = resolve_reference_glob(reference_arg)
    ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
    gen_rows = read_rows(generated_csv_path, max_rows)
    ref_rows = read_rows(ref_paths[0] if ref_paths else generated_csv_path, max_rows)

    def cats(rows, feat):
        vs = []
        for r in rows:
            v = r.get(feat)
            if v is None:
                continue
            s = str(v).strip()
            if s == "" or s.lower() == "nan":
                continue
            vs.append(s)
        return vs

    cols = 4
    rows_n = int(math.ceil(len(features) / cols)) if features else 1
    fig, axes = plt.subplots(nrows=rows_n, ncols=cols, figsize=(cols * 3.0, rows_n * 2.6), sharey=False)
    axes = np.array(axes).reshape(rows_n, cols)
    for i, feat in enumerate(features):
        rr = i // cols
        cc = i % cols
        ax = axes[rr][cc]
        gvs = cats(gen_rows, feat)
        rvs = cats(ref_rows, feat)
        all_vals = sorted(set(gvs) | set(rvs))
        if not all_vals:
            ax.axis("off")
            continue
        g_counts = {v: 0 for v in all_vals}
        r_counts = {v: 0 for v in all_vals}
        for v in gvs:
            g_counts[v] += 1
        for v in rvs:
            r_counts[v] += 1
        g_total = sum(g_counts.values()) or 1
        r_total = sum(r_counts.values()) or 1
        g_p = [g_counts[v] / g_total for v in all_vals]
        r_p = [r_counts[v] / r_total for v in all_vals]
        x = np.arange(len(all_vals))
        w = 0.42
        ax.bar(x - w / 2, g_p, width=w, color="#2563eb", alpha=0.85, label="generated")
        ax.bar(x + w / 2, r_p, width=w, color="#ef4444", alpha=0.75, label="real")
        ax.set_title(feat, fontsize=9, loc="left")
        ax.set_xticks(x)
        ax.set_xticklabels(all_vals, rotation=25, ha="right", fontsize=8)
        ax.set_ylim(0, 1)
        ax.grid(True, axis="y", color="#e5e7eb")
    for j in range(i + 1, rows_n * cols):
        rr = j // cols
        cc = j % cols
        axes[rr][cc].axis("off")
    handles, labels = axes[0][0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper center", ncol=2, fontsize=9)
    fig.suptitle("Discrete Marginals: generated vs real", fontsize=12, color="#111827", y=0.98)
    fig.tight_layout(rect=(0, 0, 1, 0.96))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, format="svg")
    plt.close(fig)


def read_csv_rows(path):
    """Read a CSV as a list of dict rows; return [] when the file does not exist."""
    p = Path(path)
    if not p.exists():
        return []
    with p.open("r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        return list(reader)


def read_json(path):
    """Parse a JSON file; return None when the file does not exist."""
    p = Path(path)
    if not p.exists():
        return None
    return json.loads(p.read_text(encoding="utf-8"))


def parse_float(s):
    """Parse a float, treating None/''/'none'/'nan' (case-insensitive) as missing."""
    if s is None:
        return None
    ss = str(s).strip()
    if ss == "" or ss.lower() == "none" or ss.lower() == "nan":
        return None
    return float(ss)


def compute_csv_means(path):
    """Compute the per-column mean of all parseable numeric cells in a CSV."""
    p = Path(path)
    if not p.exists():
        return {}
    with p.open("r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        sums = {}
        counts = {}
        for r in reader:
            for k, v in r.items():
                x = parse_float(v)
                if x is None:
                    continue
                sums[k] = sums.get(k, 0.0) + x
                counts[k] = counts.get(k, 0) + 1
    means = {}
    for k, s in sums.items():
        c = counts.get(k, 0)
        if c > 0:
            means[k] = s / c
    return means


def build_mean_profile(ks_rows, cont_stats, generated_csv_path, order="ks_desc", max_features=64):
    """Build the panel-A mean-agreement profile: range-normalized real vs generated means.

    Returns a dict with parallel feature/real/gen/ks/diff lists, per-point dicts,
    and summary stats (n, Pearson r, MAE), or None when inputs are unusable.
    """
    if not isinstance(cont_stats, dict):
        return None
    means_real = cont_stats.get("mean")
    if not isinstance(means_real, dict):
        return None
    means_gen = compute_csv_means(generated_csv_path)
    if not means_gen:
        return None
    ks_by_feat = {}
    mins = {}
    maxs = {}
    for r in ks_rows or []:
        feat = (r.get("feature") or "").strip()
        if not feat:
            continue
        ks = parse_float(r.get("ks"))
        mn = parse_float(r.get("real_min"))
        mx = parse_float(r.get("real_max"))
        if ks is not None:
            ks_by_feat[feat] = ks
        if mn is not None:
            mins[feat] = mn
        if mx is not None:
            maxs[feat] = mx
    feats = []
    real_vals = []
    gen_vals = []
    ks_vals = []
    for feat, mu_real in means_real.items():
        if feat not in means_gen:
            continue
        if feat not in mins or feat not in maxs:
            continue
        mn = mins[feat]
        mx = maxs[feat]
        denom = mx - mn
        if denom == 0:
            continue  # constant feature: cannot range-normalize
        mu_gen = means_gen[feat]
        feats.append(feat)
        real_vals.append(clamp((mu_real - mn) / denom, 0.0, 1.0))
        gen_vals.append(clamp((mu_gen - mn) / denom, 0.0, 1.0))
        ks_vals.append(ks_by_feat.get(feat))
    if not feats:
        return None
    idx = list(range(len(feats)))
    if order == "name":
        idx.sort(key=lambda i: feats[i])
    else:
        idx.sort(key=lambda i: ks_by_feat.get(feats[i], -1.0), reverse=True)
    if isinstance(max_features, int) and max_features > 0 and len(idx) > max_features:
        idx = idx[:max_features]
    sel_feats = [feats[i] for i in idx]
    sel_real = [real_vals[i] for i in idx]
    sel_gen = [gen_vals[i] for i in idx]
    sel_ks = [ks_vals[i] for i in idx]
    sel_diff = [abs(a - b) for a, b in zip(sel_real, sel_gen, strict=True)]

    def pearsonr(x, y):
        # Plain-Python Pearson correlation; None when undefined (constant input).
        if not x or len(x) != len(y):
            return None
        n = len(x)
        mx = sum(x) / n
        my = sum(y) / n
        vx = sum((xi - mx) ** 2 for xi in x)
        vy = sum((yi - my) ** 2 for yi in y)
        if vx <= 0 or vy <= 0:
            return None
        cov = sum((xi - mx) * (yi - my) for xi, yi in zip(x, y, strict=True))
        return cov / math.sqrt(vx * vy)

    r = pearsonr(sel_real, sel_gen)
    mae = (sum(sel_diff) / len(sel_diff)) if sel_diff else None
    points = []
    for f, xr, yg, ks, d in zip(sel_feats, sel_real, sel_gen, sel_ks, sel_diff, strict=True):
        points.append({"feature": f, "x": xr, "y": yg, "ks": ks, "diff": d})
    return {
        "features": sel_feats,
        "real": sel_real,
        "gen": sel_gen,
        "ks": sel_ks,
        "diff": sel_diff,
        "points": points,
        "stats": {"n": len(sel_feats), "r": r, "mae": mae},
    }


def zscores(vals):
    """Return population z-scores for vals ([] for empty, zeros for constant input)."""
    if not vals:
        return []
    m = sum(vals) / len(vals)
    v = sum((x - m) * (x - m) for x in vals) / len(vals)
    s = math.sqrt(v)
    if s == 0:
        return [0.0 for _ in vals]
    return [(x - m) / s for x in vals]


def panel_svg(bh_rows, ks_rows, shift_rows, hist_rows, filtered_metrics, profile, out_path):
    """Render the four-panel paper-style overview figure as a pure-SVG string.

    NOTE(review): many SVG tag literals below were stripped to "" in the
    recovered source (kwargs indicate the intended markup), and the function
    body continues beyond this recovered chunk — this transcription ends at the
    last visible statement.
    """
    W, H = 1400, 900
    margin = 42
    gap = 26
    panel_w = (W - margin * 2 - gap) / 2
    panel_h = (H - margin * 2 - gap) / 2
    bg = "#ffffff"
    ink = "#111827"
    subtle = "#6b7280"
    border = "#e5e7eb"
    grid = "#eef2f7"
    blue = "#3b82f6"
    red = "#ef4444"
    green = "#10b981"
    card_bg = "#ffffff"
    card_shadow = "#0f172a"
    plot_bg = "#f8fafc"

    def panel_rect(x, y):
        # NOTE(review): stripped shadow/card rect markup.
        return ("" "").format(
            x=x + 2.0, y=y + 3.0, x0=x, y0=y, w=panel_w, h=panel_h, f=card_bg, b=border, s=card_shadow
        )

    def text(x, y, s, size=12, anchor="start", color=ink, weight="normal"):
        # NOTE(review): stripped <text ...>{t}</text> markup.
        return "{t}".format(
            x=x,
            y=y,
            a=anchor,
            ff="system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif",
            fs=size,
            c=color,
            w=weight,
            t=svg_escape(s),
        )

    def line(x1, y1, x2, y2, color=border, width=1.0, dash=None, opacity=1.0, cap="round"):
        extra = ""
        if dash:
            extra += " stroke-dasharray='{d}'".format(d=dash)
        if opacity != 1.0:
            extra += " stroke-opacity='{o}'".format(o=opacity)
        # NOTE(review): stripped <line .../> markup.
        return "".format(x1=x1, y1=y1, x2=x2, y2=y2, c=color, w=width, cap=cap, extra=extra)

    def round_box(x, y, w, h, fill="#ffffff", stroke=border, sw=1.2, rx=12):
        # NOTE(review): stripped rounded-rect markup.
        return "".format(x=x, y=y, w=w, h=h, rx=rx, f=fill, s=stroke, sw=sw)

    def arrow(x1, y1, x2, y2, color=ink, width=1.8):
        # Compute an arrowhead triangle at the (x2, y2) end of the segment.
        ang = math.atan2(y2 - y1, x2 - x1)
        ah = 10.0
        aw = 5.0
        hx = x2 - ah * math.cos(ang)
        hy = y2 - ah * math.sin(ang)
        px = aw * math.sin(ang)
        py = -aw * math.cos(ang)
        p1x, p1y = hx + px, hy + py
        p2x, p2y = hx - px, hy - py
        # NOTE(review): stripped line + polygon markup.
        return ("" "").format(x1=x1, y1=y1, x2=x2, y2=y2, c=color, w=width, p1x=p1x, p1y=p1y, p2x=p2x, p2y=p2y)

    parts = []
    parts.append("".format(w=W, h=H))  # NOTE(review): stripped <svg> open tag
    parts.append("".format(w=W, h=H, bg=bg))  # NOTE(review): stripped background rect
    parts.append("" "" "" "" "")  # NOTE(review): stripped defs/filters markup
    parts.append(text(W / 2, 32, "Benchmark Overview (HAI Security Dataset)", size=18, anchor="middle", weight="bold"))
    parts.append(
        text(
            W / 2,
            54,
            "A: per-feature mean profile · B: per-feature KS · C: train-file mean shift · D: seed robustness and metric history",
            size=11,
            anchor="middle",
            color=subtle,
        )
    )
    xA, yA = margin, margin + 36
    xB, yB = margin + panel_w + gap, margin + 36
    xC, yC = margin, margin + 36 + panel_h + gap
    xD, yD = margin + panel_w + gap, margin + 36 + panel_h + gap
    parts.append(panel_rect(xA, yA))
    parts.append(panel_rect(xB, yB))
    parts.append(panel_rect(xC, yC))
    parts.append(panel_rect(xD, yD))

    def panel_label(x, y, letter, title):
        parts.append(text(x + 18, y + 28, letter, size=16, weight="bold"))
        parts.append(text(x + 44, y + 28, title, size=14, weight="bold"))

    panel_label(xA, yA, "A", "Feature-wise Similarity Profile")
    panel_label(xB, yB, "B", "Feature-Level Distribution Fidelity")
    panel_label(xC, yC, "C", "Dataset Shift Across Training Files")
    panel_label(xD, yD, "D", "Robustness Across Seeds")

    # --- Panel A: square scatter of generated vs real feature means. ---
    ax0 = xA + 22
    ay0 = yA + 56
    aw0 = panel_w - 44
    ah0 = panel_h - 78
    chart_pad_l = 54
    chart_pad_r = 14
    chart_pad_t = 24
    chart_pad_b = 52
    cx0 = ax0 + chart_pad_l
    cx1 = ax0 + aw0 - chart_pad_r
    cy0 = ay0 + chart_pad_t
    cy1 = ay0 + ah0 - chart_pad_b
    plot_side = min(cx1 - cx0, cy1 - cy0)
    cx1p = cx0 + plot_side
    parts.append(text(ax0, ay0 + 6, "Mean agreement (continuous features)", size=11, color=subtle, weight="bold"))
    parts.append(text(ax0 + aw0, ay0 + 6, "range-normalized by real min/max", size=10, anchor="end", color=subtle))
    parts.append(round_box(cx0 - 10, cy0 - 10, (cx1p - cx0) + 20, (cy1 - cy0) + 20, fill=plot_bg, stroke=border, sw=1.0, rx=14))
    parts.append(line(cx0, cy0, cx0, cy1, color=border, width=1.0, cap="butt"))
    parts.append(line(cx0, cy1, cx1p, cy1, color=border, width=1.0, cap="butt"))
    for t, lbl in [(0.0, "0.0"), (0.5, "0.5"), (1.0, "1.0")]:
        xx = cx0 + t * (cx1p - cx0)
        yy = cy1 - t * (cy1 - cy0)
        parts.append(line(xx, cy0, xx, cy1, color=grid, width=1.0, dash="4,6"))
        parts.append(line(cx0, yy, cx1p, yy, color=grid, width=1.0, dash="4,6"))
        parts.append(text(cx0 - 8, yy + 4, lbl, size=9, anchor="end", color=subtle))
        parts.append(text(xx, cy1 + 22, lbl, size=9, anchor="middle", color=subtle))
    # Rotated y-axis label; NOTE(review): surrounding <text> markup stripped.
    parts.append(
        "" "Generated mean" "".format(
            x=ax0 + 10,
            y=(cy0 + cy1) / 2,
            ff="system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif",
            c=subtle,
        )
    )
    parts.append(text((cx0 + cx1p) / 2, cy1 + 48, "Real mean", size=10, anchor="middle", color=subtle, weight="bold"))
    legend_x = cx1p + 22
    legend_y = cy0 + 6
    legend_w = 226
    legend_h = 74
    parts.append(round_box(legend_x - 10, legend_y - 14, legend_w, legend_h, fill="#ffffff", stroke=border, sw=1.0, rx=12))
    parts.append("".format(x=legend_x + 4, y=legend_y + 4, c=blue))  # NOTE(review): stripped legend marker
    parts.append(text(legend_x + 16, legend_y + 8, "KDE density", size=10, color=ink, weight="bold"))
    parts.append("".format(x1=legend_x + 2, x2=legend_x + 18, y=legend_y + 26, c="#1d4ed8"))  # NOTE(review): stripped legend line
    parts.append(text(legend_x + 24, legend_y + 30, "Contours", size=10, color=ink, weight="bold"))
    parts.append("".format(x1=legend_x + 2, x2=legend_x + 18, y=legend_y + 48, c="#94a3b8"))  # NOTE(review): stripped legend line
    parts.append(text(legend_x + 24, legend_y + 52, "y=x", size=10, color=ink, weight="bold"))
    if isinstance(profile, dict) and profile.get("points"):
        pts = profile["points"]
        stats = profile.get("stats") if isinstance(profile.get("stats"), dict) else {}
        n = stats.get("n")
        r = stats.get("r")
        mae = stats.get("mae")
        # Keep points a couple of pixels inside the plot frame.
        pad_px = 2.5
        eps = pad_px / max(1e-9, (cx1p - cx0))

        def X(v):
            return cx0 + clamp(v, eps, 1.0 - eps) * (cx1p - cx0)

        def Y(v):
            return cy1 - clamp(v, eps, 1.0 - eps) * (cy1 - cy0)

        parts.append("".format(x=cx0, y=cy0, w=(cx1p - cx0), h=(cy1 - cy0)))  # NOTE(review): stripped clip rect
        parts.append("")  # NOTE(review): stripped group open tag
        parts.append("")  # NOTE(review): stripped group open tag
        for p in pts:
            x = X(p["x"])
            y = Y(p["y"])
            parts.append("".format(x=x, y=y, c=blue))  # NOTE(review): stripped point circle
        parts.append("")  # NOTE(review): stripped group close tag
        parts.append(line(cx0, cy1, cx1p, cy0, color="#94a3b8", width=2.0, dash="6,6", opacity=0.9))
        out_idx = sorted(range(len(pts)), key=lambda i: (pts[i].get("diff") or 0.0), reverse=True)[:5]
        # Gaussian KDE over the unit square, evaluated on a (grid_n+1)^2 lattice.
        grid_n = 64
        bw = 0.085
        bw2 = bw * bw
        kde = [[0.0 for _ in range(grid_n + 1)] for _ in range(grid_n + 1)]
        for iy in range(grid_n + 1):
            yy = iy / grid_n
            for ix in range(grid_n + 1):
                xx = ix / grid_n
                s = 0.0
                for p in pts:
                    dx = xx - clamp(p["x"], 0.0, 1.0)
                    dy = yy - clamp(p["y"], 0.0, 1.0)
                    s += math.exp(-(dx * dx + dy * dy) / (2.0 * bw2))
                kde[iy][ix] = s
        vmax = max(max(row) for row in kde) if kde else 1.0
        if vmax <= 0:
            vmax = 1.0
        # NOTE(review): function continues beyond this recovered chunk.
for iy in range(grid_n + 1): for ix in range(grid_n + 1): kde[iy][ix] /= vmax segments_by_level = { 1: [(3, 0)], 2: [(0, 1)], 3: [(3, 1)], 4: [(1, 2)], 5: [(3, 2), (0, 1)], 6: [(0, 2)], 7: [(3, 2)], 8: [(2, 3)], 9: [(0, 2)], 10: [(0, 3), (1, 2)], 11: [(1, 2)], 12: [(1, 3)], 13: [(0, 1)], 14: [(3, 0)], } def edge_pt(edge, x0, y0, x1, y1, v00, v10, v11, v01, level): if edge == 0: a, b = v00, v10 t = 0.5 if b == a else (level - a) / (b - a) return x0 + clamp(t, 0.0, 1.0) * (x1 - x0), y0 if edge == 1: a, b = v10, v11 t = 0.5 if b == a else (level - a) / (b - a) return x1, y0 + clamp(t, 0.0, 1.0) * (y1 - y0) if edge == 2: a, b = v01, v11 t = 0.5 if b == a else (level - a) / (b - a) return x0 + clamp(t, 0.0, 1.0) * (x1 - x0), y1 a, b = v00, v01 t = 0.5 if b == a else (level - a) / (b - a) return x0, y0 + clamp(t, 0.0, 1.0) * (y1 - y0) contour_levels = [0.18, 0.30, 0.42, 0.54, 0.66, 0.78] for li, lev in enumerate(contour_levels): segs = [] for iy in range(grid_n): y0n = iy / grid_n y1n = (iy + 1) / grid_n for ix in range(grid_n): x0n = ix / grid_n x1n = (ix + 1) / grid_n v00 = kde[iy][ix] v10 = kde[iy][ix + 1] v11 = kde[iy + 1][ix + 1] v01 = kde[iy + 1][ix] c0 = 1 if v00 >= lev else 0 c1 = 2 if v10 >= lev else 0 c2 = 4 if v11 >= lev else 0 c3 = 8 if v01 >= lev else 0 idx = c0 | c1 | c2 | c3 pairs = segments_by_level.get(idx) if not pairs: continue for e0, e1 in pairs: ax, ay = edge_pt(e0, x0n, y0n, x1n, y1n, v00, v10, v11, v01, lev) bx, by = edge_pt(e1, x0n, y0n, x1n, y1n, v00, v10, v11, v01, lev) segs.append((X(ax), Y(ay), X(bx), Y(by))) if segs: d = " ".join("M {x1:.1f} {y1:.1f} L {x2:.1f} {y2:.1f}".format(x1=a, y1=b, x2=c, y2=d) for a, b, c, d in segs) op = 0.14 + 0.09 * li sw = 0.9 + 0.18 * li parts.append( "".format( d=d, c="#1d4ed8", w=sw, o=op ) ) for i in out_idx: p = pts[i] x = X(p["x"]) y = Y(p["y"]) parts.append("".format(x=x, y=y, c=red)) parts.append("".format(x=x, y=y, c=red)) parts.append("") for i in out_idx: p = pts[i] x = X(p["x"]) y = Y(p["y"]) dx = 18 
if p["y"] >= p["x"] else -18 if x <= cx0 + 16 and dx < 0: dx = 18 if x >= cx1p - 16 and dx > 0: dx = -18 dy = -12 if y <= cy0 + 12: dy = 14 anchor = "start" if dx > 0 else "end" parts.append(line(x, y, x + dx, y + dy + 2, color="#94a3b8", width=1.2, dash="3,5", opacity=0.9)) parts.append(text(x + dx, y + dy, p["feature"], size=9, anchor=anchor, color=ink, weight="bold")) s1 = "n={n}".format(n=n if isinstance(n, int) else len(pts)) s2 = "Pearson r={r:.3f}".format(r=r) if isinstance(r, float) else "Pearson r=NA" s3 = "MAE={m:.3f}".format(m=mae) if isinstance(mae, float) else "MAE=NA" parts.append( text( legend_x - 10, legend_y + legend_h + 18, s1 + " · " + s2 + " · " + s3, size=10, color=subtle, weight="bold", ) ) else: parts.append( text( (cx0 + cx1p) / 2, (cy0 + cy1) / 2, "missing cont_stats.json or generated.csv", size=12, anchor="middle", color=subtle, weight="bold", ) ) bx0 = xB + 22 by0 = yB + 62 bw0 = panel_w - 44 bh0 = panel_h - 86 ks_sorted = sorted( [ { "feature": r.get("feature", ""), "ks": parse_float(r.get("ks")), "gen_frac_at_min": parse_float(r.get("gen_frac_at_min")), "gen_frac_at_max": parse_float(r.get("gen_frac_at_max")), } for r in ks_rows if parse_float(r.get("ks")) is not None ], key=lambda x: x["ks"], reverse=True, ) top_n = 14 if len(ks_sorted) >= 14 else len(ks_sorted) top = ks_sorted[:top_n] dropped = [] if isinstance(filtered_metrics, dict): for d in filtered_metrics.get("dropped_features", []) or []: feat = d.get("feature") if feat: dropped.append(feat) parts.append(text(bx0, by0 - 8, "Top-{n} KS outliers (lower is better)".format(n=top_n), size=11, color=subtle)) if dropped: parts.append(text(bx0 + bw0, by0 - 8, "dropped: {d}".format(d=", ".join(dropped)), size=10, anchor="end", color=subtle)) chart_y0 = by0 + 16 chart_h = bh0 - 32 label_w = 180 x0 = bx0 + label_w x1 = bx0 + bw0 - 16 row_h = chart_h / max(1, top_n) for t in range(6): xx = x0 + (x1 - x0) * (t / 5) parts.append(line(xx, chart_y0, xx, chart_y0 + chart_h, color=grid, 
width=1.0)) parts.append(text(xx, chart_y0 + chart_h + 18, "{:.1f}".format(t / 5), size=9, anchor="middle", color=subtle)) parts.append(line(x0, chart_y0 + chart_h, x1, chart_y0 + chart_h, color=border, width=1.2, cap="butt")) for i, r in enumerate(top): fy = chart_y0 + i * row_h + row_h / 2 feat = r["feature"] ks = r["ks"] gmin = r["gen_frac_at_min"] or 0.0 gmax = r["gen_frac_at_max"] or 0.0 collapsed = (gmin >= 0.98) or (gmax >= 0.98) bar_color = "#0ea5e9" if not collapsed else "#fb7185" parts.append(text(x0 - 10, fy + 4, feat, size=9, anchor="end", color=ink)) parts.append(text(x1 + 8, fy + 4, "{:.3f}".format(ks), size=9, anchor="start", color=subtle)) w = (x1 - x0) * clamp(ks, 0.0, 1.0) parts.append( "".format( x=x0, y=fy - row_h * 0.34, w=w, h=row_h * 0.68, c=bar_color ) ) if collapsed: parts.append( text( x0 + min(w, (x1 - x0) - 10), fy + 4, "collapse", size=8, anchor="end", color="#7f1d1d", weight="bold", ) ) cx0 = xC + 22 cy0 = yC + 64 cw0 = panel_w - 44 ch0 = panel_h - 88 if shift_rows: cols = list(shift_rows[0].keys()) else: cols = [] mean_cols = [c for c in cols if c.startswith("mean_")] wanted = ["mean_P1_FT01", "mean_P1_LIT01", "mean_P1_PIT01", "mean_P2_CO_rpm", "mean_P3_LIT01", "mean_P4_ST_PT01"] feats = [c for c in wanted if c in mean_cols] files = [r.get("file", "") for r in shift_rows] sample_rows = [parse_float(r.get("sample_rows")) for r in shift_rows] feat_vals = {c: [parse_float(r.get(c)) or 0.0 for r in shift_rows] for c in feats} feat_z = {c: zscores(vs) for c, vs in feat_vals.items()} parts.append(text(cx0, cy0 - 10, "Mean shift (z-score) across train files", size=11, color=subtle)) if files: parts.append(text(cx0 + cw0, cy0 - 10, "rows: {r}".format(r=", ".join(str(int(x)) for x in sample_rows if x is not None)), size=10, anchor="end", color=subtle)) heat_x0 = cx0 + 160 heat_y0 = cy0 + 10 heat_w = cw0 - 180 heat_h = ch0 - 44 n_rows = max(1, len(files)) n_cols = max(1, len(feats)) cell_w = heat_w / n_cols cell_h = heat_h / n_rows for j, c in 
enumerate(feats): label = c.replace("mean_", "") parts.append(text(heat_x0 + j * cell_w + cell_w / 2, heat_y0 - 8, label, size=9, anchor="middle", color=ink, weight="bold")) for i, f in enumerate(files): fy = heat_y0 + i * cell_h + cell_h / 2 parts.append(text(cx0 + 140, fy + 4, f, size=9, anchor="end", color=ink)) for i in range(n_rows + 1): yy = heat_y0 + i * cell_h parts.append(line(heat_x0, yy, heat_x0 + heat_w, yy, color=border, width=1.0, cap="butt")) for j in range(n_cols + 1): xx = heat_x0 + j * cell_w parts.append(line(xx, heat_y0, xx, heat_y0 + heat_h, color=border, width=1.0, cap="butt")) for i in range(len(files)): for j, c in enumerate(feats): z = feat_z[c][i] if i < len(feat_z[c]) else 0.0 fill = diverging_color(z, vmin=-2.0, vmax=2.0) parts.append( "".format( x=heat_x0 + j * cell_w + 0.6, y=heat_y0 + i * cell_h + 0.6, w=cell_w - 1.2, h=cell_h - 1.2, c=fill ) ) parts.append(text(heat_x0 + j * cell_w + cell_w / 2, heat_y0 + i * cell_h + cell_h / 2 + 4, "{:+.2f}".format(z), size=9, anchor="middle", color=ink)) lx0 = heat_x0 ly0 = heat_y0 + heat_h + 18 parts.append(text(cx0 + 140, ly0 + 4, "z", size=9, anchor="end", color=subtle)) grad_w = 240 for k in range(25): t = k / 24 z = lerp(-2.0, 2.0, t) fill = diverging_color(z) parts.append( "".format( x=lx0 + k * (grad_w / 25), y=ly0 - 6, w=(grad_w / 25) + 0.2, c=fill ) ) parts.append(text(lx0, ly0 + 16, "-2", size=9, color=subtle)) parts.append(text(lx0 + grad_w / 2, ly0 + 16, "0", size=9, anchor="middle", color=subtle)) parts.append(text(lx0 + grad_w, ly0 + 16, "+2", size=9, anchor="end", color=subtle)) dx0 = xD + 22 dy0 = yD + 62 dw0 = panel_w - 44 dh0 = panel_h - 86 metrics = [ ("avg_ks", "KS (cont.)"), ("avg_jsd", "JSD (disc.)"), ("avg_lag1_diff", "Abs Δ lag-1"), ] bh_rows = sorted(bh_rows, key=lambda r: r.get("seed", 0)) seeds = [str(r.get("seed", "")) for r in bh_rows] parts.append(text(dx0, dy0 - 10, "Seed robustness (mean ± 1 std; dots: seeds)", size=11, color=subtle)) if seeds: parts.append(text(dx0 
+ dw0, dy0 - 10, "seeds: {s}".format(s=", ".join(seeds)), size=10, anchor="end", color=subtle)) spark_h = 48 spark_gap = 10 forest_y0 = dy0 + spark_h * 3 + spark_gap * 2 + 18 forest_h = dh0 - (spark_h * 3 + spark_gap * 2 + 28) hist_clean = [] for r in hist_rows: ks = parse_float(r.get("avg_ks")) jsd = parse_float(r.get("avg_jsd")) lag = parse_float(r.get("avg_lag1_diff")) if ks is None or jsd is None or lag is None: continue hist_clean.append({"avg_ks": ks, "avg_jsd": jsd, "avg_lag1_diff": lag}) if hist_clean: for mi, (k, title) in enumerate(metrics): y0 = dy0 + mi * (spark_h + spark_gap) y1 = y0 + spark_h parts.append(round_box(dx0, y0, dw0, spark_h, fill="#f9fafb", stroke=border, sw=1.0, rx=10)) parts.append(text(dx0 + 10, y0 + 18, title + " history", size=10, color=subtle, weight="bold")) vals = [r[k] for r in hist_clean] vmin = min(vals) vmax = max(vals) if vmax == vmin: vmax = vmin + 1.0 px0 = dx0 + 130 px1 = dx0 + dw0 - 12 py0 = y0 + 30 py1 = y1 - 10 parts.append(line(px0, py1, px1, py1, color=border, width=1.0, cap="butt")) parts.append(line(px0, py0, px0, py1, color=border, width=1.0, cap="butt")) pts = [] n = len(vals) for i, v in enumerate(vals): x = px0 + (px1 - px0) * (i / max(1, n - 1)) y = py1 - (v - vmin) * (py1 - py0) / (vmax - vmin) pts.append((x, y)) d = "M " + " L ".join("{:.1f} {:.1f}".format(x, y) for x, y in pts) parts.append("".format(d=d, c=blue if mi == 0 else (red if mi == 1 else green))) parts.append(text(dx0 + dw0 - 12, y0 + 18, "{:.3f}".format(vals[-1]), size=10, anchor="end", color=ink, weight="bold")) fx0 = dx0 fy0 = forest_y0 fw0 = dw0 fh0 = forest_h x0 = fx0 + 190 x1 = fx0 + fw0 - 18 for t in range(6): xx = x0 + (x1 - x0) * (t / 5) parts.append(line(xx, fy0, xx, fy0 + fh0, color=grid, width=1.0)) for mi, (key, title) in enumerate(metrics): y0 = fy0 + mi * (fh0 / 3) y1 = fy0 + (mi + 1) * (fh0 / 3) yc = (y0 + y1) / 2 vals = [r.get(key) for r in bh_rows if r.get(key) is not None] if not vals: continue m, s = mean_std(vals) vmin = 
# --- tail of panel_svg(...) (its `def` line is above this chunk) ------------
# NOTE(review): the first expression below is the right-hand side of a
# `vmin = ...` assignment whose left-hand side sits on the previous physical
# line of this mangled chunk; it belongs to the forest-plot loop of the
# pure-SVG panel renderer. Several SVG template literals in this region
# appear truncated to "" by extraction — restore them from version control.
min(vals + [m - s])
        vmax = max(vals + [m + s])
        if vmax == vmin:
            # Degenerate range (all seed values identical): widen so X() below
            # does not divide by zero.
            vmax = vmin + 1.0
        vr = vmax - vmin
        # Pad the value range by 15% on each side so dots don't touch the axes.
        vmin -= 0.15 * vr
        vmax += 0.15 * vr

        def X(v):
            # Map a metric value into the horizontal pixel span [x0, x1].
            return x0 + (v - vmin) * (x1 - x0) / (vmax - vmin)

        parts.append(text(x0 - 14, yc + 4, title, size=11, anchor="end", color=ink, weight="bold"))
        parts.append(line(x0, yc, x1, yc, color=border, width=1.2, cap="butt"))
        # Shaded band for mean ± 1 std (literal appears stripped — see NOTE above).
        parts.append(
            "".format(
                x=X(m - s), y=yc - 10, w=max(1.0, X(m + s) - X(m - s)), c=red
            )
        )
        parts.append(line(X(m), yc - 14, X(m), yc + 14, color=red, width=2.4))
        parts.append(text(x1, yc - 16, "mean={m:.4f}±{s:.4f}".format(m=m, s=s), size=9, anchor="end", color=subtle))
        for i, v in enumerate(vals):
            # Deterministic pseudo-jitter so overlapping seed dots stay visible.
            jitter = ((i * 37) % 9 - 4) * 1.2
            parts.append("".format(x=X(v), y=yc + jitter, c=blue))
    parts.append("")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("\n".join(parts), encoding="utf-8")


def panel_matplotlib(bh_rows, ks_rows, shift_rows, hist_rows, filtered_metrics, profile, out_path):
    """Render the 2x2 benchmark overview panel with matplotlib and save as SVG.

    Panels: A = real-vs-generated per-feature mean scatter, B = top-14 KS bars,
    C = per-train-file mean z-score heatmap, D = seed-robustness forest plot.

    Args:
        bh_rows: benchmark-history rows (dicts with seed/avg_ks/avg_jsd/avg_lag1_diff).
        ks_rows: per-feature KS rows from ks_per_feature.csv.
        shift_rows: rows from data_shift_stats.csv (one per train file).
        hist_rows: metrics_history.csv rows.
            NOTE(review): unused in this matplotlib variant; kept for signature
            parity with the pure-SVG renderer, which draws sparkline history.
        filtered_metrics: parsed filtered_metrics.json (or None).
        profile: dict with "points"/"stats" from build_mean_profile (or None).
        out_path: pathlib.Path of the output SVG; parents are created.
    """
    # Imported lazily so the pure-SVG fallback works without matplotlib installed.
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    try:
        # Best-effort styling; older matplotlib may not ship this style name.
        plt.style.use("seaborn-v0_8-whitegrid")
    except Exception:
        pass
    fig = plt.figure(figsize=(13.6, 8.6))
    gs = fig.add_gridspec(2, 2, wspace=0.18, hspace=0.22)
    axA = fig.add_subplot(gs[0, 0])
    axB = fig.add_subplot(gs[0, 1])
    axC = fig.add_subplot(gs[1, 0])
    axD = fig.add_subplot(gs[1, 1])
    fig.suptitle("Benchmark Overview (HAI Security Dataset)", fontsize=16, y=0.98)

    # ---- Panel A: feature-wise mean agreement (real vs generated) ----------
    axA.set_title("A Feature-wise Similarity Profile", loc="left", fontsize=12, fontweight="bold")
    if isinstance(profile, dict) and profile.get("points"):
        pts = profile["points"]
        xr = [p["x"] for p in pts]  # real means (range-normalized)
        yg = [p["y"] for p in pts]  # generated means (range-normalized)
        # `diff` may be absent/None per point; fall back to |x - y|.
        diffs = [p.get("diff") or abs(p["x"] - p["y"]) for p in pts]
        # Five worst-agreeing features get highlighted and labeled.
        out_idx = sorted(range(len(pts)), key=lambda i: diffs[i], reverse=True)[:5]
        axA.plot([0, 1], [0, 1], linestyle="--", color="#94a3b8", lw=1.8, dashes=(6, 6), label="y=x")
        try:
            # Density view; falls back to a plain scatter if hexbin fails.
            hb = axA.hexbin(
                xr,
                yg,
                gridsize=26,
                extent=(0, 1, 0, 1),
                cmap="Blues",
                mincnt=1,
                linewidths=0.0,
                alpha=0.95,
            )
            hb.set_edgecolor("face")
        except Exception:
            axA.scatter(xr, yg, s=22, color="#3b82f6", alpha=0.35, edgecolors="none")
        if out_idx:
            # Hollow red ring + solid red dot for each outlier feature.
            axA.scatter(
                [xr[i] for i in out_idx],
                [yg[i] for i in out_idx],
                s=90,
                facecolors="none",
                edgecolors="#ef4444",
                linewidths=2.0,
                alpha=0.95,
            )
            axA.scatter([xr[i] for i in out_idx], [yg[i] for i in out_idx], s=22, color="#ef4444", alpha=0.95, edgecolors="white", linewidths=0.8)
        for i in out_idx:
            axA.annotate(
                pts[i]["feature"],
                (xr[i], yg[i]),
                textcoords="offset points",
                xytext=(12, -10),
                ha="left",
                va="top",
                fontsize=8,
                color="#111827",
                fontweight="bold",
                arrowprops=dict(arrowstyle="-", color="#94a3b8", lw=1.0, linestyle=(0, (3, 4))),
            )
        axA.set_xlim(0.0, 1.0)
        axA.set_ylim(0.0, 1.0)
        axA.set_xlabel("Real mean (range-normalized)")
        axA.set_ylabel("Generated mean (range-normalized)")
        axA.set_aspect("equal", adjustable="box")
        axA.grid(True, color="#eef2f7")
        axA.legend(loc="lower right", frameon=False, fontsize=9)
        stats = profile.get("stats") if isinstance(profile.get("stats"), dict) else {}
        n = stats.get("n") if isinstance(stats.get("n"), int) else len(pts)
        r = stats.get("r")
        mae = stats.get("mae")
        # r / mae may legitimately be missing; render "NA" rather than crash.
        s2 = "Pearson r={r:.3f}".format(r=r) if isinstance(r, float) else "Pearson r=NA"
        s3 = "MAE={m:.3f}".format(m=mae) if isinstance(mae, float) else "MAE=NA"
        axA.text(0.02, 0.98, "n={n} · {s2} · {s3}".format(n=n, s2=s2, s3=s3), transform=axA.transAxes, ha="left", va="top", fontsize=9, color="#6b7280", fontweight="bold")
    else:
        axA.axis("off")
        axA.text(0.5, 0.5, "missing cont_stats.json or generated.csv", ha="center", va="center", fontsize=12, color="#6b7280")

    # ---- Panel B: top-14 per-feature KS bars -------------------------------
    axB.set_title("B Feature-Level Distribution Fidelity", loc="left", fontsize=12, fontweight="bold")
    ks_sorted = sorted(
        [
            {
                "feature": r.get("feature", ""),
                "ks": parse_float(r.get("ks")),
                "gen_frac_at_min": parse_float(r.get("gen_frac_at_min")) or 0.0,
                "gen_frac_at_max": parse_float(r.get("gen_frac_at_max")) or 0.0,
            }
            for r in ks_rows
            if parse_float(r.get("ks")) is not None
        ],
        key=lambda x: x["ks"],
        reverse=True,
    )
    top = ks_sorted[:14]
    # Reversed so the worst KS ends up at the top of the horizontal bar chart.
    feats = [r["feature"] for r in top][::-1]
    vals = [r["ks"] for r in top][::-1]
    # "Collapsed" = generator put >= 98% of mass at the range min or max.
    collapsed = [((r["gen_frac_at_min"] >= 0.98) or (r["gen_frac_at_max"] >= 0.98)) for r in top][::-1]
    colors = ["#fb7185" if c else "#0ea5e9" for c in collapsed]
    axB.barh(feats, vals, color=colors)
    axB.set_xlabel("KS (lower is better)")
    axB.set_xlim(0, 1.0)
    if isinstance(filtered_metrics, dict) and filtered_metrics.get("dropped_features"):
        dropped = ", ".join(d.get("feature", "") for d in filtered_metrics["dropped_features"] if d.get("feature"))
        if dropped:
            axB.text(0.99, 0.02, "dropped: {d}".format(d=dropped), transform=axB.transAxes, ha="right", va="bottom", fontsize=9, color="#6b7280")

    # ---- Panel C: per-train-file mean z-score heatmap ----------------------
    axC.set_title("C Dataset Shift Across Training Files", loc="left", fontsize=12, fontweight="bold")
    if shift_rows:
        cols = list(shift_rows[0].keys())
    else:
        cols = []
    mean_cols = [c for c in cols if c.startswith("mean_")]
    # Fixed, hand-picked sensor columns; only those present in the CSV are used.
    wanted = ["mean_P1_FT01", "mean_P1_LIT01", "mean_P1_PIT01", "mean_P2_CO_rpm", "mean_P3_LIT01", "mean_P4_ST_PT01"]
    feats = [c for c in wanted if c in mean_cols]
    files = [r.get("file", "") for r in shift_rows]
    M = []  # one list of per-file values per selected feature
    for c in feats:
        M.append([parse_float(r.get(c)) or 0.0 for r in shift_rows])
    if M and files and feats:
        # z-score each feature column, then transpose to rows=files, cols=feats.
        Z = list(zip(*[zscores(col) for col in M], strict=True))
        im = axC.imshow(Z, aspect="auto", cmap="coolwarm", vmin=-2, vmax=2)
        axC.set_yticks(range(len(files)))
        axC.set_yticklabels(files)
        axC.set_xticks(range(len(feats)))
        axC.set_xticklabels([f.replace("mean_", "") for f in feats], rotation=25, ha="right")
        axC.set_ylabel("Train file")
        axC.set_xlabel("Feature mean z-score")
        fig.colorbar(im, ax=axC, fraction=0.046, pad=0.04)
    else:
        axC.axis("off")
        axC.text(0.5, 0.5, "missing data_shift_stats.csv", ha="center", va="center", fontsize=11, color="#6b7280")

    # ---- Panel D: seed-robustness forest plot (drawn in axes coords) -------
    axD.set_title("D Robustness Across Seeds", loc="left", fontsize=12, fontweight="bold")
    axD.axis("off")
    axD.set_xlim(0, 1)
    axD.set_ylim(0, 1)
    metrics = [("avg_ks", "KS (cont.)"), ("avg_jsd", "JSD (disc.)"), ("avg_lag1_diff", "Abs Δ lag-1")]
    bh_rows = sorted(bh_rows, key=lambda r: r.get("seed", 0))
    for mi, (k, title) in enumerate(metrics):
        vals = [r.get(k) for r in bh_rows if r.get(k) is not None]
        if not vals:
            continue
        m, s = mean_std(vals)
        y = 0.78 - mi * 0.22  # one horizontal row per metric
        axD.text(0.04, y, title, fontsize=10, fontweight="bold", va="center")
        x0 = 0.42
        x1 = 0.96
        vmin = min(vals + [m - s])
        vmax = max(vals + [m + s])
        if vmax == vmin:
            # Avoid zero-width range when all seeds produced identical values.
            vmax = vmin + 1.0
        vr = vmax - vmin
        vmin -= 0.15 * vr
        vmax += 0.15 * vr

        def X(v):
            # Map a metric value into this row's [x0, x1] axes-fraction span.
            return x0 + (v - vmin) * (x1 - x0) / (vmax - vmin)

        # Translucent band for mean ± 1 std, tick at the mean.
        axD.add_patch(patches.FancyBboxPatch((X(m - s), y - 0.03), max(0.002, X(m + s) - X(m - s)), 0.06, boxstyle="round,pad=0.01,rounding_size=0.02", facecolor="#ef4444", alpha=0.12, edgecolor="none"))
        axD.plot([X(m), X(m)], [y - 0.05, y + 0.05], color="#ef4444", lw=2.2)
        jit = [-0.03, 0.0, 0.03]  # cycle of vertical offsets to separate seed dots
        for i, v in enumerate(vals):
            axD.scatter([X(v)], [y + jit[i % len(jit)]], s=40, color="#3b82f6", edgecolor="white", linewidth=0.9, zorder=3)
        axD.text(0.96, y + 0.07, "mean={m:.4f}±{s:.4f}".format(m=m, s=s), fontsize=9, color="#6b7280", ha="right")
    fig.tight_layout(rect=(0, 0, 1, 0.96))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, format="svg")
    plt.close(fig)


def main():
    """CLI entry point: load benchmark_history.csv, dispatch on --figure, save SVG.

    Raises SystemExit when the history CSV is missing or empty. For the
    "summary" and "panel" figures, engine "auto" tries matplotlib first and
    silently falls back to the pure-SVG renderer; engine "matplotlib"
    re-raises the failure instead of falling back.
    """
    args = parse_args()
    hist_path = Path(args.history)
    if not hist_path.exists():
        raise SystemExit("missing history file: %s" % hist_path)
    rows = []
    with hist_path.open("r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        for r in reader:
            # Strict parsing: a malformed row raises (KeyError/ValueError)
            # rather than being silently skipped.
            rows.append(
                {
                    "run_name": r["run_name"],
                    "seed": int(r["seed"]),
                    "avg_ks": float(r["avg_ks"]),
                    "avg_jsd": float(r["avg_jsd"]),
                    "avg_lag1_diff": float(r["avg_lag1_diff"]),
                }
            )
    if not rows:
        raise SystemExit("no rows in history file: %s" % hist_path)
    rows = sorted(rows, key=lambda x: x["seed"])
    seeds = [str(r["seed"]) for r in rows]
    metrics = [
        ("avg_ks", "KS (continuous)"),
        ("avg_jsd", "JSD (discrete)"),
        ("avg_lag1_diff", "Abs Δ lag-1 autocorr"),
    ]
    # Default output path depends on the figure; per-figure branches below may
    # override it again when --out was not given.
    if args.out:
        out_path = Path(args.out)
    else:
        if args.figure == "panel":
            out_path = Path(__file__).resolve().parent.parent / "figures" / "benchmark_panel.svg"
        elif args.figure == "summary":
            out_path = Path(__file__).resolve().parent.parent / "figures" / "benchmark_metrics.svg"
        else:
            out_path = Path(__file__).resolve().parent.parent / "figures" / "ranked_ks.svg"
    if args.figure == "summary":
        if args.engine in {"auto", "matplotlib"}:
            try:
                plot_matplotlib(rows, seeds, metrics, out_path)
                print("saved", out_path)
                return
            except Exception:
                # "auto" degrades to the pure-SVG renderer on any failure.
                if args.engine == "matplotlib":
                    raise
        plot_svg(rows, seeds, metrics, out_path)
        print("saved", out_path)
        return
    if args.figure == "ranked_ks":
        ranked_rows = read_csv_rows(args.ranked_ks)
        ranked_ks_svg(ranked_rows, out_path, top_n=args.ranked_ks_top_n)
        print("saved", out_path)
        return
    if args.figure == "lines":
        feats_arg = [f.strip() for f in (args.lines_features or "").split(",") if f.strip()]
        features = feats_arg
        if not features:
            # No explicit features: take the top-K worst-KS features.
            rk_rows = read_csv_rows(args.ranked_ks)
            if rk_rows:
                sorted_rows = sorted(
                    [{"feature": (r.get("feature") or "").strip(), "ks": parse_float(r.get("ks"))} for r in rk_rows if (r.get("feature") or "").strip()],
                    key=lambda x: (x["ks"] if x["ks"] is not None else -1.0),
                    reverse=True,
                )
                features = [r["feature"] for r in sorted_rows[:max(1, int(args.lines_top_k))]]
        if not features:
            # Hard-coded fallback set when ranked_ks.csv is unavailable/empty.
            features = ["P1_B4002", "P1_PIT02", "P1_FCV02Z", "P1_B3004"]
        if not args.out:
            out_path = Path(__file__).resolve().parent.parent / "figures" / "lines.svg"
        cont_stats = read_json(args.cont_stats)
        lines_matplotlib(args.generated, cont_stats, features, out_path, max_rows=args.lines_max_rows, normalize=args.lines_normalize, reference_arg=args.reference, ref_index=args.lines_ref_index)
        print("saved", out_path)
        return
    if args.figure == "cdf_grid":
        cont_stats = read_json(args.cont_stats)
        feats_arg = [f.strip() for f in (args.cdf_features or "").split(",") if f.strip()]
        if feats_arg:
            features = feats_arg
        else:
            # All continuous features known to cont_stats, alphabetically.
            features = sorted(list((cont_stats.get("mean") or {}).keys()))
        if args.cdf_max_features > 0:
            features = features[: args.cdf_max_features]
        if not args.out:
            out_path = Path(__file__).resolve().parent.parent / "figures" / "cdf_grid.svg"
        cdf_grid_matplotlib(args.generated, args.reference, cont_stats, features, out_path, max_rows=max(1000, args.lines_max_rows), bins=args.cdf_bins)
        print("saved", out_path)
        return
    if args.figure == "disc_grid":
        # NOTE(review): relies on args.feature_split, which is not visible in
        # the truncated parse_args() shown here — confirm the flag exists.
        split = read_json(args.feature_split)
        disc_list = list((split.get("discrete") or [])) if isinstance(split, dict) else []
        feats_arg = [f.strip() for f in (args.disc_features or "").split(",") if f.strip()]
        if feats_arg:
            features = feats_arg
        else:
            features = sorted(disc_list)
        if args.disc_max_features > 0:
            features = features[: args.disc_max_features]
        if not args.out:
            out_path = Path(__file__).resolve().parent.parent / "figures" / "disc_grid.svg"
        discrete_grid_matplotlib(args.generated, args.reference, features, out_path, max_rows=max(1000, args.lines_max_rows))
        print("saved", out_path)
        return
    if args.figure == "disc_points":
        # Same feature-selection logic as disc_grid, different renderer.
        split = read_json(args.feature_split)
        disc_list = list((split.get("discrete") or [])) if isinstance(split, dict) else []
        feats_arg = [f.strip() for f in (args.disc_features or "").split(",") if f.strip()]
        if feats_arg:
            features = feats_arg
        else:
            features = sorted(disc_list)
        if args.disc_max_features > 0:
            features = features[: args.disc_max_features]
        if not args.out:
            out_path = Path(__file__).resolve().parent.parent / "figures" / "disc_points.svg"
        discrete_points_matplotlib(args.generated, args.reference, features, out_path, max_rows=max(1000, args.lines_max_rows))
        print("saved", out_path)
        return
    # Default: the multi-panel overview figure.
    ks_rows = read_csv_rows(args.ks_per_feature)
    shift_rows = read_csv_rows(args.data_shift)
    mh_rows = read_csv_rows(args.metrics_history)
    fm = read_json(args.filtered_metrics)
    cont_stats = read_json(args.cont_stats)
    profile = build_mean_profile(ks_rows, cont_stats, args.generated, order=args.profile_order, max_features=args.profile_max_features)
    bh_rows = [{"seed": r["seed"], "avg_ks": r["avg_ks"], "avg_jsd": r["avg_jsd"], "avg_lag1_diff": r["avg_lag1_diff"]} for r in rows]
    if args.engine in {"auto", "matplotlib"}:
        try:
            panel_matplotlib(bh_rows, ks_rows, shift_rows, mh_rows, fm, profile, out_path)
            print("saved", out_path)
            return
        except Exception:
            if args.engine == "matplotlib":
                raise
    panel_svg(bh_rows, ks_rows, shift_rows, mh_rows, fm, profile, out_path)
    print("saved", out_path)


if __name__ == "__main__":
    main()