#!/usr/bin/env python3
import argparse
import csv
import json
import math
from pathlib import Path
def parse_args(argv=None):
    """Parse command-line arguments for the benchmark plotting tool.

    Args:
        argv: optional list of argument strings; None (the default, and the
            previous behavior) parses sys.argv. Added for testability and
            programmatic use — fully backward compatible.

    Returns:
        argparse.Namespace with the figure selection, artifact paths, and
        per-figure tuning options. Default paths resolve relative to the
        directory containing this script.
    """
    parser = argparse.ArgumentParser(description="Plot benchmark metrics from benchmark_history.csv")
    base_dir = Path(__file__).resolve().parent
    parser.add_argument(
        "--figure",
        choices=["panel", "summary", "ranked_ks", "lines", "cdf_grid", "disc_grid", "disc_points"],
        default="panel",
        # Fixed: the help text previously omitted disc_grid and disc_points
        # even though both are valid choices.
        help=(
            "Figure type: panel (paper-style multi-panel), summary (seed robustness only), "
            "ranked_ks (outlier attribution), lines (feature series), cdf_grid (distributions), "
            "disc_grid (discrete marginal bars), or disc_points (discrete marginal dot plot)."
        ),
    )
    parser.add_argument(
        "--history",
        default=str(base_dir / "results" / "benchmark_history.csv"),
        help="Path to benchmark_history.csv",
    )
    parser.add_argument(
        "--generated",
        default=str(base_dir / "results" / "generated.csv"),
        help="Path to generated.csv (used for per-feature profile in panel A).",
    )
    parser.add_argument(
        "--cont-stats",
        default=str(base_dir / "results" / "cont_stats.json"),
        help="Path to cont_stats.json (used for per-feature profile in panel A).",
    )
    parser.add_argument(
        "--profile-order",
        choices=["ks_desc", "name"],
        default="ks_desc",
        help="Feature ordering for panel A profile: ks_desc or name.",
    )
    parser.add_argument(
        "--profile-max-features",
        type=int,
        default=64,
        help="Max features in panel A profile (0 means all).",
    )
    parser.add_argument(
        "--ks-per-feature",
        default=str(base_dir / "results" / "ks_per_feature.csv"),
        help="Path to ks_per_feature.csv",
    )
    parser.add_argument(
        "--data-shift",
        default=str(base_dir / "results" / "data_shift_stats.csv"),
        help="Path to data_shift_stats.csv",
    )
    parser.add_argument(
        "--metrics-history",
        default=str(base_dir / "results" / "metrics_history.csv"),
        help="Path to metrics_history.csv",
    )
    parser.add_argument(
        "--filtered-metrics",
        default=str(base_dir / "results" / "filtered_metrics.json"),
        help="Path to filtered_metrics.json (optional).",
    )
    parser.add_argument(
        "--ranked-ks",
        default=str(base_dir / "results" / "ranked_ks.csv"),
        help="Path to ranked_ks.csv (used for --figure ranked_ks).",
    )
    parser.add_argument(
        "--ranked-ks-top-n",
        type=int,
        default=30,
        help="Number of top KS features to show in ranked_ks figure.",
    )
    parser.add_argument(
        "--out",
        default="",
        help="Output SVG path (default depends on --figure).",
    )
    parser.add_argument(
        "--engine",
        choices=["auto", "matplotlib", "svg"],
        default="auto",
        help="Plotting engine: auto prefers matplotlib if available; otherwise uses pure-SVG.",
    )
    parser.add_argument(
        "--lines-features",
        default="",
        help="Comma-separated feature names for --figure lines (default: top-4 from ranked_ks.csv or fallback set).",
    )
    parser.add_argument(
        "--lines-top-k",
        type=int,
        default=8,
        help="When features not specified, take top-K features from ranked_ks.csv.",
    )
    parser.add_argument(
        "--lines-max-rows",
        type=int,
        default=1000,
        help="Max rows to read from generated.csv for --figure lines.",
    )
    parser.add_argument(
        "--lines-normalize",
        choices=["none", "real_range"],
        default="none",
        help="Normalization for --figure lines: none or real_range (use cont_stats min/max).",
    )
    parser.add_argument(
        "--reference",
        default=str(base_dir / "config.json"),
        help="Reference source for real data: config.json with data_glob or direct CSV/GZ path.",
    )
    parser.add_argument(
        "--lines-ref-index",
        type=int,
        default=0,
        help="Index of matched reference file to plot (0-based) when using a glob.",
    )
    parser.add_argument(
        "--cdf-features",
        default="",
        help="Comma-separated features for cdf_grid; empty = all continuous features from cont_stats.",
    )
    parser.add_argument(
        "--cdf-max-features",
        type=int,
        default=64,
        help="Max features to include in cdf_grid.",
    )
    parser.add_argument(
        "--cdf-bins",
        type=int,
        default=80,
        help="Number of bins for empirical CDF.",
    )
    parser.add_argument(
        "--disc-features",
        default="",
        help="Comma-separated features for disc_grid; empty = use discrete list from feature_split.json.",
    )
    parser.add_argument(
        "--disc-max-features",
        type=int,
        default=64,
        help="Max features to include in disc_grid.",
    )
    parser.add_argument(
        "--feature-split",
        default=str(base_dir / "feature_split.json"),
        help="Path to feature_split.json with 'continuous' and 'discrete' lists.",
    )
    return parser.parse_args(argv)
def mean_std(vals):
    """Return (mean, sample standard deviation) of a non-empty sequence.

    A single observation has no spread, so its std is reported as 0.0
    (the n-1 denominator would otherwise divide by zero).
    """
    n = len(vals)
    mean = sum(vals) / n
    if n == 1:
        return mean, 0.0
    squared_dev = sum((x - mean) * (x - mean) for x in vals)
    return mean, math.sqrt(squared_dev / (n - 1))
def svg_escape(s):
    """Escape a value for safe embedding in SVG/XML text or attribute content.

    Fixed: the previous version replaced each metacharacter with itself
    (the entity names had been lost), so '&', '<', '>', quotes leaked into
    the generated markup unescaped. Map them to XML character entities.
    '&' must be replaced first so later entities are not double-escaped.
    """
    return (
        str(s)
        .replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
        .replace("'", "&#39;")
    )
def clamp(v, lo, hi):
    """Clamp *v* into the closed interval [lo, hi]."""
    return lo if v < lo else hi if v > hi else v
def lerp(a, b, t):
    """Linearly interpolate from a (t=0) toward b (t=1)."""
    delta = b - a
    return a + t * delta
def hex_to_rgb(h):
    """Parse an '#rrggbb' color (leading '#' optional) into (r, g, b) ints."""
    digits = h.lstrip("#")
    r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return r, g, b
def rgb_to_hex(r, g, b):
    """Format r/g/b channel values (clamped to 0..255, truncated) as '#rrggbb'."""
    channels = []
    for c in (r, g, b):
        # Inline clamp to the displayable byte range.
        if c < 0:
            c = 0
        elif c > 255:
            c = 255
        channels.append(int(c))
    return "#{:02x}{:02x}{:02x}".format(*channels)
def diverging_color(v, vmin=-2.0, vmax=2.0, cold="#2563eb", hot="#ef4444", mid="#ffffff"):
    """Map v onto a diverging color ramp: mid→hot for v>=0, mid→cold for v<0.

    v is clamped into [vmin, vmax] first; the blend factor is |v| relative
    to the corresponding extreme, with a guard for a zero-width half-range.
    """
    v = clamp(v, vmin, vmax)
    if v >= 0:
        frac = v / vmax if vmax != 0 else 0.0
        target = hot
    else:
        frac = (-v) / (-vmin) if vmin != 0 else 0.0
        target = cold
    r0, g0, b0 = hex_to_rgb(mid)
    r1, g1, b1 = hex_to_rgb(target)
    return rgb_to_hex(lerp(r0, r1, frac), lerp(g0, g1, frac), lerp(b0, b1, frac))
def plot_matplotlib(rows, seeds, metrics, out_path):
    """Render the seed-robustness summary figure with matplotlib.

    Args:
        rows: list of dicts, one per run, keyed by metric name.
        seeds: per-row labels annotated next to each point.
        metrics: list of (key, title) pairs; one horizontal strip per metric.
        out_path: destination SVG Path (parent dirs are created).
    """
    import matplotlib.pyplot as plt
    try:
        # Preferred grid style; silently fall back on matplotlib versions
        # that do not ship this style name.
        plt.style.use("seaborn-v0_8-whitegrid")
    except Exception:
        pass
    fig, axes = plt.subplots(nrows=len(metrics), ncols=1, figsize=(8.6, 4.6), sharex=False)
    if len(metrics) == 1:
        axes = [axes]  # subplots returns a bare Axes when nrows == 1
    point_color = "#3b82f6"
    band_color = "#ef4444"
    grid_color = "#e5e7eb"
    axis_color = "#111827"
    for ax, (key, title) in zip(axes, metrics, strict=True):
        vals = [r[key] for r in rows]
        m, s = mean_std(vals)
        # Expand x-limits so both the points and the ±1σ band are visible.
        vmin = min(vals + [m - s])
        vmax = max(vals + [m + s])
        if vmax == vmin:
            vmax = vmin + 1.0  # avoid a zero-width axis
        vr = vmax - vmin
        vmin -= 0.20 * vr
        vmax += 0.20 * vr
        y0 = 0.0
        # Small vertical jitter keeps near-equal values distinguishable.
        jitter = [-0.08, 0.0, 0.08]
        ys = [(y0 + jitter[i % len(jitter)]) for i in range(len(vals))]
        ax.axvspan(m - s, m + s, color=band_color, alpha=0.10, linewidth=0)
        ax.axvline(m, color=band_color, linewidth=2.2)
        ax.scatter(vals, ys, s=46, color=point_color, edgecolors="white", linewidths=1.0, zorder=3)
        for x, y, seed in zip(vals, ys, seeds, strict=True):
            ax.annotate(
                str(seed),
                (x, y),
                textcoords="offset points",
                xytext=(0, 10),
                ha="center",
                va="bottom",
                fontsize=8,
                color=axis_color,
            )
        ax.set_title(title, loc="left", fontsize=11, color=axis_color, pad=8)
        ax.set_yticks([])
        ax.set_ylim(-0.35, 0.35)
        ax.set_xlim(vmin, vmax)
        ax.grid(True, axis="x", color=grid_color)
        ax.grid(False, axis="y")
        ax.text(
            0.99,
            0.80,
            "mean={m:.4f} ± {s:.4f}".format(m=m, s=s),
            transform=ax.transAxes,
            ha="right",
            va="center",
            fontsize=9,
            color="#374151",
        )
    fig.suptitle("Benchmark Metrics (3 seeds) · lower is better", fontsize=12, color=axis_color, y=0.98)
    fig.tight_layout(rect=(0, 0, 1, 0.95))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, format="svg")
    plt.close(fig)
def discrete_points_matplotlib(generated_csv_path, reference_arg, features, out_path, max_rows=5000):
    """Dot-plot of per-category probabilities, generated vs real.

    Each (feature, category) pair occupies one integer x slot; a blank slot
    separates feature groups. Writes an SVG figure to out_path.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import gzip
    try:
        plt.style.use("seaborn-v0_8-whitegrid")
    except Exception:
        pass
    def resolve_reference_glob(ref_arg: str) -> str:
        # A .json reference is a config whose data_glob/data_path is taken
        # relative to the config's own directory; any other path is used
        # directly. Glob patterns are returned unresolved.
        ref_path = Path(ref_arg)
        if ref_path.suffix == ".json":
            cfg = json.loads(ref_path.read_text(encoding="utf-8"))
            data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
            if not data_glob:
                raise SystemExit("reference config has no data_glob/data_path")
            combined = ref_path.parent / data_glob
            if "*" in str(combined) or "?" in str(combined):
                return str(combined)
            return str(combined.resolve())
        return str(ref_path)
    def read_rows(path, limit):
        # Read up to `limit` CSV rows (limit <= 0 means unbounded);
        # transparently handles .gz files.
        rows = []
        opener = gzip.open if str(path).endswith(".gz") else open
        with opener(path, "rt", newline="") as fh:
            reader = csv.DictReader(fh)
            for i, r in enumerate(reader):
                rows.append(r)
                if limit > 0 and i + 1 >= limit:
                    break
        return rows
    def cats(rows, feat):
        # Collect non-empty, non-"nan" string values for one feature.
        vs = []
        for r in rows:
            v = r.get(feat)
            if v is None:
                continue
            s = str(v).strip()
            if s == "" or s.lower() == "nan":
                continue
            vs.append(s)
        return vs
    ref_glob = resolve_reference_glob(reference_arg)
    ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
    gen_rows = read_rows(generated_csv_path, max_rows)
    # Fall back to comparing generated against itself when no reference
    # file matched the glob.
    ref_rows = read_rows(ref_paths[0] if ref_paths else generated_csv_path, max_rows)
    points = []
    group_spans = []
    x = 0
    for feat in features:
        gvs = cats(gen_rows, feat)
        rvs = cats(ref_rows, feat)
        cats_all = sorted(set(gvs) | set(rvs))
        start_x = x
        for c in cats_all:
            g_count = sum(1 for v in gvs if v == c)
            r_count = sum(1 for v in rvs if v == c)
            g_total = len(gvs) or 1  # guard div-by-zero on empty series
            r_total = len(rvs) or 1
            g_p = g_count / g_total
            r_p = r_count / r_total
            points.append({"x": x, "feat": feat, "cat": c, "g": g_p, "r": r_p})
            x += 1
        end_x = x - 1
        if end_x >= start_x:
            group_spans.append({"feat": feat, "x0": start_x, "x1": end_x})
        x += 1  # spacer slot between feature groups
    if not points:
        # Nothing to draw: emit a placeholder figure instead of failing.
        fig, ax = plt.subplots(figsize=(9, 3))
        ax.axis("off")
        ax.text(0.5, 0.5, "no discrete data", ha="center", va="center")
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, format="svg")
        plt.close(fig)
        return
    # Figure width scales with point count, bounded for readability.
    width = max(9.0, min(18.0, 0.08 * len(points)))
    fig, ax = plt.subplots(figsize=(width, 6.2))
    xs = np.array([p["x"] for p in points], dtype=float)
    jitter = 0.18  # horizontal offset separating the two series
    ax.scatter(xs - jitter, [p["g"] for p in points], s=26, color="#2563eb", alpha=0.85, edgecolors="white", linewidths=0.6, label="generated")
    ax.scatter(xs + jitter, [p["r"] for p in points], s=26, color="#ef4444", alpha=0.75, edgecolors="white", linewidths=0.6, label="real")
    for span in group_spans:
        xc = (span["x0"] + span["x1"]) / 2.0
        ax.axvline(span["x1"] + 0.5, color="#e5e7eb", lw=1.0)
        ax.text(xc, -0.06, span["feat"], ha="center", va="top", rotation=25, fontsize=8, color="#374151")
    xg_all = xs - jitter
    yg_all = [p["g"] for p in points]
    xr_all = xs + jitter
    yr_all = [p["r"] for p in points]
    # Connect and lightly shade each series to emphasize the profile shape.
    ax.plot(xg_all, yg_all, color="#2563eb", linewidth=1.2, alpha=0.85)
    ax.plot(xr_all, yr_all, color="#ef4444", linewidth=1.2, alpha=0.75)
    ax.fill_between(xg_all, yg_all, 0.0, color="#2563eb", alpha=0.14, step=None)
    ax.fill_between(xr_all, yr_all, 0.0, color="#ef4444", alpha=0.14, step=None)
    ax.set_ylim(-0.08, 1.02)  # headroom below 0 for the feature labels
    ax.set_xlim(-0.5, max(xs) + 0.5)
    ax.set_ylabel("probability", fontsize=10)
    ax.set_xticks([])
    ax.grid(True, axis="y", color="#e5e7eb")
    ax.legend(loc="upper right", fontsize=9)
    fig.suptitle("Discrete Marginals (dot plot): generated vs real", fontsize=12, color="#111827", y=0.98)
    fig.tight_layout(rect=(0, 0, 1, 0.96))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, format="svg")
    plt.close(fig)
def plot_svg(rows, seeds, metrics, out_path):
    """Pure-SVG fallback for the summary figure (no matplotlib required).

    NOTE(review): the single parts.append("") below appends an empty
    string, so the written file contains no actual SVG markup. The layout
    math and palette above it are unused. The template content appears to
    have been lost from this source — restore it before relying on the
    pure-SVG engine path.
    """
    W, H = 980, 440
    pad_l, pad_r, pad_t, pad_b = 200, 30, 74, 36
    row_gap = 52
    row_h = (H - pad_t - pad_b - row_gap * (len(metrics) - 1)) / len(metrics)
    # Palette (currently unreferenced by the emitted markup — see NOTE).
    bg = "#ffffff"
    axis = "#2b2b2b"
    grid = "#e9e9e9"
    band = "#d62728"
    band_fill = "#d62728"
    point = "#1f77b4"
    text = "#111111"
    subtle = "#666666"
    parts = []
    parts.append(
        "")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("\n".join(parts), encoding="utf-8")
def ranked_ks_svg(ranked_rows, out_path, top_n=30):
    """Pure-SVG outlier-attribution figure built from ranked_ks.csv rows.

    NOTE(review): the helper `text` interpolates only {t}, `line` formats
    an empty template, and the final parts.append("") emits no markup —
    the SVG templates appear to have been stripped from this source, so
    the written file will be effectively blank.
    """
    # Normalize raw CSV rows into typed records; rows lacking a feature
    # name or a parseable KS value are dropped.
    parsed = []
    for r in ranked_rows or []:
        feat = (r.get("feature") or "").strip()
        if not feat:
            continue
        ks = parse_float(r.get("ks"))
        if ks is None:
            continue
        rank = parse_float(r.get("rank"))
        contrib = parse_float(r.get("contribution_to_avg"))
        avg_if = parse_float(r.get("avg_ks_if_remove_top_n"))
        parsed.append(
            {
                "rank": int(rank) if isinstance(rank, (int, float)) else None,
                "feature": feat,
                "ks": float(ks),
                "contrib": float(contrib) if isinstance(contrib, (int, float)) else None,
                "avg_if": float(avg_if) if isinstance(avg_if, (int, float)) else None,
            }
        )
    if not parsed:
        raise SystemExit("no valid rows in ranked_ks.csv")
    parsed_by_ks = sorted(parsed, key=lambda x: x["ks"], reverse=True)
    top_n = max(1, int(top_n))
    # NOTE(review): `top` is computed but not referenced in the visible
    # remainder — presumably the missing drawing code consumed it.
    top = parsed_by_ks[: min(top_n, len(parsed_by_ks))]
    parsed_by_rank = sorted(
        [p for p in parsed if isinstance(p.get("rank"), int) and p.get("avg_if") is not None],
        key=lambda x: x["rank"],
    )
    # Reconstruct the overall average KS before any removal: the rank-1
    # row's avg-if-removed plus its own contribution; fall back to a plain
    # mean over all parsed rows.
    baseline = None
    if parsed_by_rank:
        first = parsed_by_rank[0]
        if first.get("avg_if") is not None and first.get("contrib") is not None:
            baseline = first["avg_if"] + first["contrib"]
        elif first.get("avg_if") is not None:
            baseline = first["avg_if"]
    if baseline is None:
        baseline = sum(p["ks"] for p in parsed_by_ks) / len(parsed_by_ks)
    # Series for the "avg KS if top-k removed" curve, starting at baseline.
    xs = [0]
    ys = [baseline]
    for p in parsed_by_rank:
        xs.append(p["rank"])
        ys.append(p["avg_if"])
    W, H = 980, 560
    bg = "#ffffff"
    ink = "#0f172a"
    subtle = "#64748b"
    grid = "#e2e8f0"
    border = "#cbd5e1"
    blue = "#2563eb"
    bar = "#0ea5e9"
    ff = "system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif"
    def text(x, y, s, size=12, anchor="start", color=ink, weight="normal"):
        # NOTE(review): only {t} is interpolated; the surrounding SVG text
        # element markup is missing from the template.
        return "{t}".format(
            x=x,
            y=y,
            a=anchor,
            ff=ff,
            fs=size,
            c=color,
            w=weight,
            t=svg_escape(s),
        )
    def line(x1, y1, x2, y2, color=grid, width=1.0, dash=None, opacity=1.0, cap="round"):
        extra = ""
        if dash:
            extra += " stroke-dasharray='{d}'".format(d=dash)
        if opacity != 1.0:
            extra += " stroke-opacity='{o}'".format(o=opacity)
        # NOTE(review): empty template — the SVG line element markup is missing.
        return "".format(
            x1=x1,
            y1=y1,
            x2=x2,
            y2=y2,
            c=color,
            w=width,
            cap=cap,
            extra=extra,
        )
    # Two-card layout geometry (left/right panels under a title band).
    pad = 26
    title_h = 62
    gap = 18
    card_w = (W - 2 * pad - gap) / 2
    card_h = H - pad - title_h
    xA = pad
    xB = pad + card_w + gap
    y0 = title_h
    parts = []
    parts.append("")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("\n".join(parts), encoding="utf-8")
def lines_matplotlib(generated_csv_path, cont_stats, features, out_path, max_rows=1000, normalize="none", reference_arg="", ref_index=0):
    """Plot per-feature series for generated vs reference data, stacked.

    normalize="real_range" rescales values by cont_stats min/max; the
    reference is resolved from a config.json data_glob or a direct path,
    and ref_index selects which matched file to plot.
    """
    import matplotlib.pyplot as plt
    import gzip
    try:
        plt.style.use("seaborn-v0_8-whitegrid")
    except Exception:
        pass
    rows_gen = []
    with Path(generated_csv_path).open("r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        for i, r in enumerate(reader):
            rows_gen.append(r)
            if len(rows_gen) >= max_rows:
                break
    # X values: the 'time' column when parseable, else the row index.
    # NOTE(review): a present-but-unparsable 'time' value yields 0.0 for
    # that row rather than the index — confirm this fallback is intended.
    xs_gen = [parse_float(r.get("time")) or (i if (r.get("time") is None) else 0.0) for i, r in enumerate(rows_gen)]
    mins = {}
    maxs = {}
    if isinstance(cont_stats, dict):
        mins = cont_stats.get("min", {}) or {}
        maxs = cont_stats.get("max", {}) or {}
    def norm_val(feat, v):
        # Map v into the real-data [min, max] range; pass through when
        # normalization is off or the range is unknown/degenerate.
        if normalize != "real_range":
            return v
        mn = parse_float(mins.get(feat))
        mx = parse_float(maxs.get(feat))
        if mn is None or mx is None:
            return v
        denom = mx - mn
        if denom == 0:
            return v
        return (v - mn) / denom
    def resolve_reference_glob(ref_arg: str) -> str:
        # config.json -> data_glob/data_path relative to the config's dir;
        # other paths are used directly. Glob patterns stay unresolved.
        ref_path = Path(ref_arg)
        if ref_path.suffix == ".json":
            cfg = json.loads(ref_path.read_text(encoding="utf-8"))
            data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
            if not data_glob:
                raise SystemExit("reference config has no data_glob/data_path")
            combined = ref_path.parent / data_glob
            if "*" in str(combined) or "?" in str(combined):
                return str(combined)
            return str(combined.resolve())
        return str(ref_path)
    def read_series(path: Path, cols, max_rows: int):
        # Column-wise float series; unparsable cells are silently skipped.
        # NOTE(review): defined but not called in the visible code.
        vals = {c: [] for c in cols}
        opener = gzip.open if str(path).endswith(".gz") else open
        with opener(path, "rt", newline="") as fh:
            reader = csv.DictReader(fh)
            for i, row in enumerate(reader):
                for c in cols:
                    try:
                        vals[c].append(float(row[c]))
                    except Exception:
                        pass
                if max_rows > 0 and i + 1 >= max_rows:
                    break
        return vals
    ref_glob = resolve_reference_glob(reference_arg or str(Path(__file__).resolve().parent / "config.json"))
    ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
    ref_rows = []
    if ref_paths:
        idx = max(0, min(ref_index, len(ref_paths) - 1))  # clamp into range
        first = ref_paths[idx]
        opener = gzip.open if str(first).endswith(".gz") else open
        with opener(first, "rt", newline="") as fh:
            reader = csv.DictReader(fh)
            for i, r in enumerate(reader):
                ref_rows.append(r)
                if len(ref_rows) >= max_rows:
                    break
    xs_ref = [i for i, _ in enumerate(ref_rows)]
    fig, axes = plt.subplots(nrows=len(features), ncols=1, figsize=(9.2, 6.6), sharex=True)
    if len(features) == 1:
        axes = [axes]  # subplots returns a bare Axes when nrows == 1
    for ax, feat in zip(axes, features, strict=True):
        # Unparsable generated values plot as 0.0 (before normalization).
        ys_gen = [norm_val(feat, parse_float(r.get(feat)) or 0.0) for r in rows_gen]
        ax.plot(xs_gen, ys_gen, color="#2563eb", linewidth=1.6, label="generated")
        if ref_rows:
            ys_ref = [norm_val(feat, parse_float(r.get(feat)) or 0.0) for r in ref_rows]
            ax.plot(xs_ref, ys_ref, color="#ef4444", linewidth=1.2, alpha=0.75, label="real")
        ax.set_ylabel(feat, fontsize=10)
        ax.grid(True, color="#e5e7eb")
        ax.legend(loc="upper right", fontsize=8)
    axes[-1].set_xlabel("time", fontsize=10)
    fig.suptitle("Feature Series: generated vs real", fontsize=12, color="#111827", y=0.98)
    fig.tight_layout(rect=(0, 0, 1, 0.96))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, format="svg")
    plt.close(fig)
def cdf_grid_matplotlib(generated_csv_path, reference_arg, cont_stats, features, out_path, max_rows=5000, bins=80):
    """Grid of empirical CDFs (generated vs real), four panels per row."""
    import matplotlib.pyplot as plt
    import gzip
    import numpy as np
    try:
        plt.style.use("seaborn-v0_8-whitegrid")
    except Exception:
        pass
    mins = cont_stats.get("min", {}) if isinstance(cont_stats, dict) else {}
    maxs = cont_stats.get("max", {}) if isinstance(cont_stats, dict) else {}
    def resolve_reference_glob(ref_arg: str) -> str:
        # config.json -> data_glob/data_path relative to the config's dir;
        # other paths are used directly. Glob patterns stay unresolved.
        ref_path = Path(ref_arg)
        if ref_path.suffix == ".json":
            cfg = json.loads(ref_path.read_text(encoding="utf-8"))
            data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
            if not data_glob:
                raise SystemExit("reference config has no data_glob/data_path")
            combined = ref_path.parent / data_glob
            if "*" in str(combined) or "?" in str(combined):
                return str(combined)
            return str(combined.resolve())
        return str(ref_path)
    def read_rows(path, limit):
        # Up to `limit` rows (<= 0 means all); handles .gz transparently.
        rows = []
        opener = gzip.open if str(path).endswith(".gz") else open
        with opener(path, "rt", newline="") as fh:
            reader = csv.DictReader(fh)
            for i, r in enumerate(reader):
                rows.append(r)
                if limit > 0 and i + 1 >= limit:
                    break
        return rows
    gen_rows = read_rows(generated_csv_path, max_rows)
    ref_glob = resolve_reference_glob(reference_arg)
    ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
    # Fall back to self-comparison when no reference file matched.
    ref_rows = read_rows(ref_paths[0] if ref_paths else generated_csv_path, max_rows)
    def values(rows, feat):
        # Parseable float values for one feature.
        vs = []
        for r in rows:
            x = parse_float(r.get(feat))
            if x is not None:
                vs.append(x)
        return vs
    cols = 4
    rows_n = int(math.ceil(len(features) / cols)) if features else 1
    fig, axes = plt.subplots(nrows=rows_n, ncols=cols, figsize=(cols * 3.2, rows_n * 2.6))
    axes = np.array(axes).reshape(rows_n, cols)  # uniform 2-D indexing
    for i, feat in enumerate(features):
        rr = i // cols
        cc = i % cols
        ax = axes[rr][cc]
        gvs = values(gen_rows, feat)
        rvs = values(ref_rows, feat)
        mn = parse_float(mins.get(feat))
        mx = parse_float(maxs.get(feat))
        if mn is None or mx is None or mx <= mn:
            # No usable real-data range: derive bounds from observed values.
            lo = min(gvs + rvs) if (gvs or rvs) else 0.0
            hi = max(gvs + rvs) if (gvs or rvs) else (lo + 1.0)
        else:
            lo = mn
            hi = mx
        edges = np.linspace(lo, hi, bins + 1)
        def ecdf(vs):
            # Histogram-based empirical CDF on the shared bin edges;
            # empty input yields an all-zero curve.
            if not vs:
                return edges[1:], np.zeros_like(edges[1:])
            hist, _ = np.histogram(vs, bins=edges)
            cdf = np.cumsum(hist).astype(float)
            cdf /= cdf[-1] if cdf[-1] > 0 else 1.0
            xs = edges[1:]
            return xs, cdf
        xg, yg = ecdf(gvs)
        xr, yr = ecdf(rvs)
        ax.plot(xg, yg, color="#2563eb", linewidth=1.6, label="generated")
        ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
        ax.set_title(feat, fontsize=9, loc="left")
        ax.set_ylim(0, 1)
        ax.grid(True, color="#e5e7eb")
    # Hide any unused trailing axes in the grid.
    # NOTE(review): `i` is unbound here if `features` is empty — confirm
    # callers never pass an empty list.
    for j in range(i + 1, rows_n * cols):
        rr = j // cols
        cc = j % cols
        axes[rr][cc].axis("off")
    handles, labels = axes[0][0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper center", ncol=2, fontsize=9)
    fig.suptitle("Empirical CDF: generated vs real", fontsize=12, color="#111827", y=0.98)
    fig.tight_layout(rect=(0, 0, 1, 0.96))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, format="svg")
    plt.close(fig)
def discrete_grid_matplotlib(generated_csv_path, reference_arg, features, out_path, max_rows=5000):
    """Grid of discrete-marginal paired bar charts (generated vs real)."""
    import matplotlib.pyplot as plt
    import numpy as np
    import gzip
    try:
        plt.style.use("seaborn-v0_8-whitegrid")
    except Exception:
        pass
    def resolve_reference_glob(ref_arg: str) -> str:
        # config.json -> data_glob/data_path relative to the config's dir;
        # other paths are used directly. Glob patterns stay unresolved.
        ref_path = Path(ref_arg)
        if ref_path.suffix == ".json":
            cfg = json.loads(ref_path.read_text(encoding="utf-8"))
            data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
            if not data_glob:
                raise SystemExit("reference config has no data_glob/data_path")
            combined = ref_path.parent / data_glob
            if "*" in str(combined) or "?" in str(combined):
                return str(combined)
            return str(combined.resolve())
        return str(ref_path)
    def read_rows(path, limit):
        # Up to `limit` rows (<= 0 means all); handles .gz transparently.
        rows = []
        opener = gzip.open if str(path).endswith(".gz") else open
        with opener(path, "rt", newline="") as fh:
            reader = csv.DictReader(fh)
            for i, r in enumerate(reader):
                rows.append(r)
                if limit > 0 and i + 1 >= limit:
                    break
        return rows
    ref_glob = resolve_reference_glob(reference_arg)
    ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
    gen_rows = read_rows(generated_csv_path, max_rows)
    # Fall back to self-comparison when no reference file matched.
    ref_rows = read_rows(ref_paths[0] if ref_paths else generated_csv_path, max_rows)
    def cats(rows, feat):
        # Non-empty, non-"nan" string values for one feature.
        vs = []
        for r in rows:
            v = r.get(feat)
            if v is None:
                continue
            s = str(v).strip()
            if s == "" or s.lower() == "nan":
                continue
            vs.append(s)
        return vs
    cols = 4
    rows_n = int(math.ceil(len(features) / cols)) if features else 1
    fig, axes = plt.subplots(nrows=rows_n, ncols=cols, figsize=(cols * 3.0, rows_n * 2.6), sharey=False)
    axes = np.array(axes).reshape(rows_n, cols)  # uniform 2-D indexing
    for i, feat in enumerate(features):
        rr = i // cols
        cc = i % cols
        ax = axes[rr][cc]
        gvs = cats(gen_rows, feat)
        rvs = cats(ref_rows, feat)
        all_vals = sorted(set(gvs) | set(rvs))
        if not all_vals:
            ax.axis("off")  # no data for this feature at all
            continue
        g_counts = {v: 0 for v in all_vals}
        r_counts = {v: 0 for v in all_vals}
        for v in gvs:
            g_counts[v] += 1
        for v in rvs:
            r_counts[v] += 1
        g_total = sum(g_counts.values()) or 1  # guard div-by-zero
        r_total = sum(r_counts.values()) or 1
        g_p = [g_counts[v] / g_total for v in all_vals]
        r_p = [r_counts[v] / r_total for v in all_vals]
        x = np.arange(len(all_vals))
        w = 0.42  # paired-bar width
        ax.bar(x - w / 2, g_p, width=w, color="#2563eb", alpha=0.85, label="generated")
        ax.bar(x + w / 2, r_p, width=w, color="#ef4444", alpha=0.75, label="real")
        ax.set_title(feat, fontsize=9, loc="left")
        ax.set_xticks(x)
        ax.set_xticklabels(all_vals, rotation=25, ha="right", fontsize=8)
        ax.set_ylim(0, 1)
        ax.grid(True, axis="y", color="#e5e7eb")
    # Hide any unused trailing axes in the grid.
    # NOTE(review): `i` is unbound here if `features` is empty — confirm
    # callers never pass an empty list.
    for j in range(i + 1, rows_n * cols):
        rr = j // cols
        cc = j % cols
        axes[rr][cc].axis("off")
    handles, labels = axes[0][0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper center", ncol=2, fontsize=9)
    fig.suptitle("Discrete Marginals: generated vs real", fontsize=12, color="#111827", y=0.98)
    fig.tight_layout(rect=(0, 0, 1, 0.96))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, format="svg")
    plt.close(fig)
def read_csv_rows(path):
    """Load a CSV file as a list of dict rows; a missing file yields []."""
    csv_path = Path(path)
    if not csv_path.exists():
        return []
    with csv_path.open("r", encoding="utf-8", newline="") as handle:
        return list(csv.DictReader(handle))
def read_json(path):
    """Parse a JSON file; return None when the file does not exist."""
    json_path = Path(path)
    if json_path.exists():
        return json.loads(json_path.read_text(encoding="utf-8"))
    return None
def parse_float(s):
    """Best-effort float conversion for CSV/JSON cell values.

    Returns None for missing data: None input, empty/whitespace strings,
    the sentinels "none"/"nan" (case-insensitive), and any token float()
    cannot parse. The last case is new — previously an arbitrary
    non-numeric cell raised ValueError, crashing callers that use this
    as a lenient reader (e.g. `parse_float(r.get(c)) or 0.0`).
    """
    if s is None:
        return None
    ss = str(s).strip()
    if ss == "" or ss.lower() in ("none", "nan"):
        return None
    try:
        return float(ss)
    except ValueError:
        return None
def compute_csv_means(path):
    """Compute per-column means of a CSV file's numeric cells.

    Returns {} when the file is missing. Cells that parse_float rejects
    are skipped; a column with no numeric cells is omitted entirely.
    """
    csv_path = Path(path)
    if not csv_path.exists():
        return {}
    sums = {}
    counts = {}
    with csv_path.open("r", encoding="utf-8", newline="") as handle:
        for record in csv.DictReader(handle):
            for col, cell in record.items():
                value = parse_float(cell)
                if value is None:
                    continue
                sums[col] = sums.get(col, 0.0) + value
                counts[col] = counts.get(col, 0) + 1
    return {col: total / counts[col] for col, total in sums.items() if counts.get(col, 0) > 0}
def build_mean_profile(ks_rows, cont_stats, generated_csv_path, order="ks_desc", max_features=64):
    """Build panel-A scatter data: per-feature range-normalized means.

    Compares real means (from cont_stats) against generated means (from
    generated.csv), normalized by the real min/max taken from ks_rows.

    Returns None when any required input is missing/unusable; otherwise a
    dict with parallel lists ("features", "real", "gen", "ks", "diff"),
    per-point records ("points"), and summary "stats" (n, Pearson r, MAE).
    """
    if not isinstance(cont_stats, dict):
        return None
    means_real = cont_stats.get("mean")
    if not isinstance(means_real, dict):
        return None
    means_gen = compute_csv_means(generated_csv_path)
    if not means_gen:
        return None
    # Index KS scores and real min/max per feature from ks_per_feature rows.
    ks_by_feat = {}
    mins = {}
    maxs = {}
    for r in ks_rows or []:
        feat = (r.get("feature") or "").strip()
        if not feat:
            continue
        ks = parse_float(r.get("ks"))
        mn = parse_float(r.get("real_min"))
        mx = parse_float(r.get("real_max"))
        if ks is not None:
            ks_by_feat[feat] = ks
        if mn is not None:
            mins[feat] = mn
        if mx is not None:
            maxs[feat] = mx
    feats = []
    real_vals = []
    gen_vals = []
    ks_vals = []
    for feat, mu_real in means_real.items():
        # Keep only features present in both datasets with a usable range.
        if feat not in means_gen:
            continue
        if feat not in mins or feat not in maxs:
            continue
        mn = mins[feat]
        mx = maxs[feat]
        denom = mx - mn
        if denom == 0:
            continue  # constant feature: normalization undefined
        mu_gen = means_gen[feat]
        feats.append(feat)
        real_vals.append(clamp((mu_real - mn) / denom, 0.0, 1.0))
        gen_vals.append(clamp((mu_gen - mn) / denom, 0.0, 1.0))
        ks_vals.append(ks_by_feat.get(feat))
    if not feats:
        return None
    idx = list(range(len(feats)))
    if order == "name":
        idx.sort(key=lambda i: feats[i])
    else:
        # Default: worst-fitting (highest-KS) features first; features
        # with no KS value sort last via the -1.0 sentinel.
        idx.sort(key=lambda i: ks_by_feat.get(feats[i], -1.0), reverse=True)
    if isinstance(max_features, int) and max_features > 0 and len(idx) > max_features:
        idx = idx[:max_features]
    sel_feats = [feats[i] for i in idx]
    sel_real = [real_vals[i] for i in idx]
    sel_gen = [gen_vals[i] for i in idx]
    sel_ks = [ks_vals[i] for i in idx]
    sel_diff = [abs(a - b) for a, b in zip(sel_real, sel_gen, strict=True)]
    def pearsonr(x, y):
        # Pearson correlation; None for degenerate (empty, mismatched,
        # or constant) input.
        if not x or len(x) != len(y):
            return None
        n = len(x)
        mx = sum(x) / n
        my = sum(y) / n
        vx = sum((xi - mx) ** 2 for xi in x)
        vy = sum((yi - my) ** 2 for yi in y)
        if vx <= 0 or vy <= 0:
            return None
        cov = sum((xi - mx) * (yi - my) for xi, yi in zip(x, y, strict=True))
        return cov / math.sqrt(vx * vy)
    r = pearsonr(sel_real, sel_gen)
    mae = (sum(sel_diff) / len(sel_diff)) if sel_diff else None
    points = []
    for f, xr, yg, ks, d in zip(sel_feats, sel_real, sel_gen, sel_ks, sel_diff, strict=True):
        points.append({"feature": f, "x": xr, "y": yg, "ks": ks, "diff": d})
    return {
        "features": sel_feats,
        "real": sel_real,
        "gen": sel_gen,
        "ks": sel_ks,
        "diff": sel_diff,
        "points": points,
        "stats": {"n": len(sel_feats), "r": r, "mae": mae},
    }
def zscores(vals):
    """Standardize values using the population mean/std.

    Empty input yields []; constant input (zero std) maps every value
    to 0.0 rather than dividing by zero.
    """
    if not vals:
        return []
    n = len(vals)
    mean = sum(vals) / n
    variance = sum((x - mean) * (x - mean) for x in vals) / n
    std = math.sqrt(variance)
    if std == 0:
        return [0.0] * n
    return [(x - mean) / std for x in vals]
def panel_svg(bh_rows, ks_rows, shift_rows, hist_rows, filtered_metrics, profile, out_path):
    """Pure-SVG fallback for the 2x2 benchmark panel figure.

    NOTE(review): the drawing helpers below format empty or {t}-only
    templates and the final parts.append("") emits no markup — the SVG
    content appears to have been stripped from this source, so the output
    file will be blank. The geometry/palette/helper scaffolding is kept
    for when the templates are restored.
    """
    W, H = 1400, 900
    margin = 42
    gap = 26
    panel_w = (W - margin * 2 - gap) / 2
    panel_h = (H - margin * 2 - gap) / 2
    # Palette shared by the (missing) drawing templates.
    bg = "#ffffff"
    ink = "#111827"
    subtle = "#6b7280"
    border = "#e5e7eb"
    grid = "#eef2f7"
    blue = "#3b82f6"
    red = "#ef4444"
    green = "#10b981"
    card_bg = "#ffffff"
    card_shadow = "#0f172a"
    plot_bg = "#f8fafc"
    def panel_rect(x, y):
        # Card background with an offset drop shadow (template currently
        # empty — see NOTE above).
        return (
            ""
            ""
        ).format(
            x=x + 2.0,
            y=y + 3.0,
            x0=x,
            y0=y,
            w=panel_w,
            h=panel_h,
            f=card_bg,
            b=border,
            s=card_shadow,
        )
    def text(x, y, s, size=12, anchor="start", color=ink, weight="normal"):
        # NOTE(review): only {t} is interpolated; the SVG text element
        # markup is missing from the template.
        return "{t}".format(
            x=x,
            y=y,
            a=anchor,
            ff="system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif",
            fs=size,
            c=color,
            w=weight,
            t=svg_escape(s),
        )
    def line(x1, y1, x2, y2, color=border, width=1.0, dash=None, opacity=1.0, cap="round"):
        extra = ""
        if dash:
            extra += " stroke-dasharray='{d}'".format(d=dash)
        if opacity != 1.0:
            extra += " stroke-opacity='{o}'".format(o=opacity)
        # NOTE(review): empty template — SVG line markup missing.
        return "".format(
            x1=x1, y1=y1, x2=x2, y2=y2, c=color, w=width, cap=cap, extra=extra
        )
    def round_box(x, y, w, h, fill="#ffffff", stroke=border, sw=1.2, rx=12):
        # NOTE(review): empty template — SVG rounded-rect markup missing.
        return "".format(
            x=x, y=y, w=w, h=h, rx=rx, f=fill, s=stroke, sw=sw
        )
    def arrow(x1, y1, x2, y2, color=ink, width=1.8):
        # Line plus a triangular head at (x2, y2); the head's corner points
        # are computed from the segment angle.
        ang = math.atan2(y2 - y1, x2 - x1)
        ah = 10.0
        aw = 5.0
        hx = x2 - ah * math.cos(ang)
        hy = y2 - ah * math.sin(ang)
        px = aw * math.sin(ang)
        py = -aw * math.cos(ang)
        p1x, p1y = hx + px, hy + py
        p2x, p2y = hx - px, hy - py
        return (
            ""
            ""
        ).format(x1=x1, y1=y1, x2=x2, y2=y2, c=color, w=width, p1x=p1x, p1y=p1y, p2x=p2x, p2y=p2y)
    parts = []
    parts.append(
        "")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("\n".join(parts), encoding="utf-8")
def panel_matplotlib(bh_rows, ks_rows, shift_rows, hist_rows, filtered_metrics, profile, out_path):
    """Render the 2x2 paper-style benchmark overview panel with matplotlib.

    Panels: A per-feature mean profile scatter, B top-KS bar chart,
    C dataset-shift heatmap, D per-seed robustness strips. hist_rows is
    accepted for interface parity but unused in the visible code.
    """
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    try:
        plt.style.use("seaborn-v0_8-whitegrid")
    except Exception:
        pass
    fig = plt.figure(figsize=(13.6, 8.6))
    gs = fig.add_gridspec(2, 2, wspace=0.18, hspace=0.22)
    axA = fig.add_subplot(gs[0, 0])
    axB = fig.add_subplot(gs[0, 1])
    axC = fig.add_subplot(gs[1, 0])
    axD = fig.add_subplot(gs[1, 1])
    fig.suptitle("Benchmark Overview (HAI Security Dataset)", fontsize=16, y=0.98)
    axA.set_title("A Feature-wise Similarity Profile", loc="left", fontsize=12, fontweight="bold")
    if isinstance(profile, dict) and profile.get("points"):
        pts = profile["points"]
        xr = [p["x"] for p in pts]
        yg = [p["y"] for p in pts]
        diffs = [p.get("diff") or abs(p["x"] - p["y"]) for p in pts]
        # The five largest |real - gen| gaps get highlight rings + labels.
        out_idx = sorted(range(len(pts)), key=lambda i: diffs[i], reverse=True)[:5]
        axA.plot([0, 1], [0, 1], linestyle="--", color="#94a3b8", lw=1.8, dashes=(6, 6), label="y=x")
        try:
            hb = axA.hexbin(
                xr,
                yg,
                gridsize=26,
                extent=(0, 1, 0, 1),
                cmap="Blues",
                mincnt=1,
                linewidths=0.0,
                alpha=0.95,
            )
            hb.set_edgecolor("face")
        except Exception:
            # hexbin can fail on degenerate input; fall back to a scatter.
            axA.scatter(xr, yg, s=22, color="#3b82f6", alpha=0.35, edgecolors="none")
        if out_idx:
            # Hollow red ring plus a solid dot to mark each outlier.
            axA.scatter(
                [xr[i] for i in out_idx],
                [yg[i] for i in out_idx],
                s=90,
                facecolors="none",
                edgecolors="#ef4444",
                linewidths=2.0,
                alpha=0.95,
            )
            axA.scatter([xr[i] for i in out_idx], [yg[i] for i in out_idx], s=22, color="#ef4444", alpha=0.95, edgecolors="white", linewidths=0.8)
            for i in out_idx:
                axA.annotate(
                    pts[i]["feature"],
                    (xr[i], yg[i]),
                    textcoords="offset points",
                    xytext=(12, -10),
                    ha="left",
                    va="top",
                    fontsize=8,
                    color="#111827",
                    fontweight="bold",
                    arrowprops=dict(arrowstyle="-", color="#94a3b8", lw=1.0, linestyle=(0, (3, 4))),
                )
        axA.set_xlim(0.0, 1.0)
        axA.set_ylim(0.0, 1.0)
        axA.set_xlabel("Real mean (range-normalized)")
        axA.set_ylabel("Generated mean (range-normalized)")
        axA.set_aspect("equal", adjustable="box")
        axA.grid(True, color="#eef2f7")
        axA.legend(loc="lower right", frameon=False, fontsize=9)
        stats = profile.get("stats") if isinstance(profile.get("stats"), dict) else {}
        n = stats.get("n") if isinstance(stats.get("n"), int) else len(pts)
        r = stats.get("r")
        mae = stats.get("mae")
        s2 = "Pearson r={r:.3f}".format(r=r) if isinstance(r, float) else "Pearson r=NA"
        s3 = "MAE={m:.3f}".format(m=mae) if isinstance(mae, float) else "MAE=NA"
        axA.text(0.02, 0.98, "n={n} · {s2} · {s3}".format(n=n, s2=s2, s3=s3), transform=axA.transAxes, ha="left", va="top", fontsize=9, color="#6b7280", fontweight="bold")
    else:
        # Inputs for panel A were not available.
        axA.axis("off")
        axA.text(0.5, 0.5, "missing cont_stats.json or generated.csv", ha="center", va="center", fontsize=12, color="#6b7280")
    axB.set_title("B Feature-Level Distribution Fidelity", loc="left", fontsize=12, fontweight="bold")
    ks_sorted = sorted(
        [
            {
                "feature": r.get("feature", ""),
                "ks": parse_float(r.get("ks")),
                "gen_frac_at_min": parse_float(r.get("gen_frac_at_min")) or 0.0,
                "gen_frac_at_max": parse_float(r.get("gen_frac_at_max")) or 0.0,
            }
            for r in ks_rows
            if parse_float(r.get("ks")) is not None
        ],
        key=lambda x: x["ks"],
        reverse=True,
    )
    top = ks_sorted[:14]
    # Reverse so the highest-KS bar renders at the top of the barh chart.
    feats = [r["feature"] for r in top][::-1]
    vals = [r["ks"] for r in top][::-1]
    # "Collapsed" = generated values pile up (>=98%) at the range min or max.
    collapsed = [((r["gen_frac_at_min"] >= 0.98) or (r["gen_frac_at_max"] >= 0.98)) for r in top][::-1]
    colors = ["#fb7185" if c else "#0ea5e9" for c in collapsed]
    axB.barh(feats, vals, color=colors)
    axB.set_xlabel("KS (lower is better)")
    axB.set_xlim(0, 1.0)
    if isinstance(filtered_metrics, dict) and filtered_metrics.get("dropped_features"):
        dropped = ", ".join(d.get("feature", "") for d in filtered_metrics["dropped_features"] if d.get("feature"))
        if dropped:
            axB.text(0.99, 0.02, "dropped: {d}".format(d=dropped), transform=axB.transAxes, ha="right", va="bottom", fontsize=9, color="#6b7280")
    axC.set_title("C Dataset Shift Across Training Files", loc="left", fontsize=12, fontweight="bold")
    if shift_rows:
        cols = list(shift_rows[0].keys())
    else:
        cols = []
    mean_cols = [c for c in cols if c.startswith("mean_")]
    # Fixed subset of representative sensor columns for the heatmap.
    wanted = ["mean_P1_FT01", "mean_P1_LIT01", "mean_P1_PIT01", "mean_P2_CO_rpm", "mean_P3_LIT01", "mean_P4_ST_PT01"]
    feats = [c for c in wanted if c in mean_cols]
    files = [r.get("file", "") for r in shift_rows]
    M = []
    for c in feats:
        M.append([parse_float(r.get(c)) or 0.0 for r in shift_rows])
    if M and files and feats:
        # Z-score each feature column, then transpose to files x features.
        Z = list(zip(*[zscores(col) for col in M], strict=True))
        im = axC.imshow(Z, aspect="auto", cmap="coolwarm", vmin=-2, vmax=2)
        axC.set_yticks(range(len(files)))
        axC.set_yticklabels(files)
        axC.set_xticks(range(len(feats)))
        axC.set_xticklabels([f.replace("mean_", "") for f in feats], rotation=25, ha="right")
        axC.set_ylabel("Train file")
        axC.set_xlabel("Feature mean z-score")
        fig.colorbar(im, ax=axC, fraction=0.046, pad=0.04)
    else:
        axC.axis("off")
        axC.text(0.5, 0.5, "missing data_shift_stats.csv", ha="center", va="center", fontsize=11, color="#6b7280")
    # Panel D is drawn manually in axes coordinates (axis itself is off).
    axD.set_title("D Robustness Across Seeds", loc="left", fontsize=12, fontweight="bold")
    axD.axis("off")
    axD.set_xlim(0, 1)
    axD.set_ylim(0, 1)
    metrics = [("avg_ks", "KS (cont.)"), ("avg_jsd", "JSD (disc.)"), ("avg_lag1_diff", "Abs Δ lag-1")]
    bh_rows = sorted(bh_rows, key=lambda r: r.get("seed", 0))
    for mi, (k, title) in enumerate(metrics):
        vals = [r.get(k) for r in bh_rows if r.get(k) is not None]
        if not vals:
            continue
        m, s = mean_std(vals)
        y = 0.78 - mi * 0.22  # one horizontal strip per metric
        axD.text(0.04, y, title, fontsize=10, fontweight="bold", va="center")
        x0 = 0.42
        x1 = 0.96
        # Pad the value range so points and the ±1σ band fit in the strip.
        vmin = min(vals + [m - s])
        vmax = max(vals + [m + s])
        if vmax == vmin:
            vmax = vmin + 1.0
        vr = vmax - vmin
        vmin -= 0.15 * vr
        vmax += 0.15 * vr
        def X(v):
            # Map a metric value onto the strip's [x0, x1] span.
            return x0 + (v - vmin) * (x1 - x0) / (vmax - vmin)
        axD.add_patch(patches.FancyBboxPatch((X(m - s), y - 0.03), max(0.002, X(m + s) - X(m - s)), 0.06, boxstyle="round,pad=0.01,rounding_size=0.02", facecolor="#ef4444", alpha=0.12, edgecolor="none"))
        axD.plot([X(m), X(m)], [y - 0.05, y + 0.05], color="#ef4444", lw=2.2)
        jit = [-0.03, 0.0, 0.03]  # vertical jitter for overlapping seeds
        for i, v in enumerate(vals):
            axD.scatter([X(v)], [y + jit[i % len(jit)]], s=40, color="#3b82f6", edgecolor="white", linewidth=0.9, zorder=3)
        axD.text(0.96, y + 0.07, "mean={m:.4f}±{s:.4f}".format(m=m, s=s), fontsize=9, color="#6b7280", ha="right")
    fig.tight_layout(rect=(0, 0, 1, 0.96))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, format="svg")
    plt.close(fig)
# Directory where figures are written when --out is not given.
_FIGURES_DIR = Path(__file__).resolve().parent.parent / "figures"

# Default output file name per figure type. Figures not listed fall back to
# ranked_ks.svg (preserving the original else-branch behavior).
_DEFAULT_OUT_NAMES = {
    "panel": "benchmark_panel.svg",
    "summary": "benchmark_metrics.svg",
    "lines": "lines.svg",
    "cdf_grid": "cdf_grid.svg",
    "disc_grid": "disc_grid.svg",
    "disc_points": "disc_points.svg",
}


def _load_history(hist_path: Path) -> list[dict]:
    """Read benchmark_history.csv into typed row dicts sorted by seed.

    Each row carries run_name (str), seed (int) and the three float metrics
    avg_ks / avg_jsd / avg_lag1_diff.

    Raises:
        SystemExit: if the file is missing or contains no data rows.
    """
    if not hist_path.exists():
        raise SystemExit("missing history file: %s" % hist_path)
    rows = []
    with hist_path.open("r", encoding="utf-8", newline="") as f:
        for r in csv.DictReader(f):
            rows.append(
                {
                    "run_name": r["run_name"],
                    "seed": int(r["seed"]),
                    "avg_ks": float(r["avg_ks"]),
                    "avg_jsd": float(r["avg_jsd"]),
                    "avg_lag1_diff": float(r["avg_lag1_diff"]),
                }
            )
    if not rows:
        raise SystemExit("no rows in history file: %s" % hist_path)
    return sorted(rows, key=lambda x: x["seed"])


def _select_features(explicit_arg, fallback, max_features):
    """Return the feature list for a grid figure.

    A non-empty comma-separated ``explicit_arg`` wins; otherwise ``fallback``
    is used. ``max_features`` > 0 truncates the result (0 means "all").
    """
    feats = [f.strip() for f in (explicit_arg or "").split(",") if f.strip()]
    features = feats if feats else fallback
    if max_features > 0:
        features = features[:max_features]
    return features


def main():
    """CLI entry point: load benchmark history and render the requested figure."""
    args = parse_args()
    rows = _load_history(Path(args.history))
    seeds = [str(r["seed"]) for r in rows]
    # Metric keys paired with their display labels for the summary figure.
    metrics = [
        ("avg_ks", "KS (continuous)"),
        ("avg_jsd", "JSD (discrete)"),
        ("avg_lag1_diff", "Abs Δ lag-1 autocorr"),
    ]
    if args.out:
        out_path = Path(args.out)
    else:
        out_path = _FIGURES_DIR / _DEFAULT_OUT_NAMES.get(args.figure, "ranked_ks.svg")
    if args.figure == "summary":
        if args.engine in {"auto", "matplotlib"}:
            try:
                plot_matplotlib(rows, seeds, metrics, out_path)
                print("saved", out_path)
                return
            except Exception:
                # "auto" falls back to the pure-SVG renderer; an explicitly
                # requested matplotlib engine propagates the failure.
                if args.engine == "matplotlib":
                    raise
        plot_svg(rows, seeds, metrics, out_path)
        print("saved", out_path)
        return
    if args.figure == "ranked_ks":
        ranked_rows = read_csv_rows(args.ranked_ks)
        ranked_ks_svg(ranked_rows, out_path, top_n=args.ranked_ks_top_n)
        print("saved", out_path)
        return
    if args.figure == "lines":
        features = [f.strip() for f in (args.lines_features or "").split(",") if f.strip()]
        if not features:
            # No explicit list: take the top-K features by KS from ranked_ks.csv.
            rk_rows = read_csv_rows(args.ranked_ks)
            if rk_rows:
                ranked = sorted(
                    [
                        {"feature": (r.get("feature") or "").strip(), "ks": parse_float(r.get("ks"))}
                        for r in rk_rows
                        if (r.get("feature") or "").strip()
                    ],
                    key=lambda x: (x["ks"] if x["ks"] is not None else -1.0),
                    reverse=True,
                )
                features = [r["feature"] for r in ranked[: max(1, int(args.lines_top_k))]]
        if not features:
            # Last-resort hard-coded default selection.
            features = ["P1_B4002", "P1_PIT02", "P1_FCV02Z", "P1_B3004"]
        cont_stats = read_json(args.cont_stats)
        lines_matplotlib(
            args.generated,
            cont_stats,
            features,
            out_path,
            max_rows=args.lines_max_rows,
            normalize=args.lines_normalize,
            reference_arg=args.reference,
            ref_index=args.lines_ref_index,
        )
        print("saved", out_path)
        return
    if args.figure == "cdf_grid":
        cont_stats = read_json(args.cont_stats)
        features = _select_features(
            args.cdf_features,
            sorted((cont_stats.get("mean") or {}).keys()),
            args.cdf_max_features,
        )
        cdf_grid_matplotlib(
            args.generated,
            args.reference,
            cont_stats,
            features,
            out_path,
            max_rows=max(1000, args.lines_max_rows),
            bins=args.cdf_bins,
        )
        print("saved", out_path)
        return
    if args.figure in {"disc_grid", "disc_points"}:
        # Both discrete figures share the same feature selection; only the
        # renderer and the default file name differ.
        split = read_json(args.feature_split)
        disc_list = list(split.get("discrete") or []) if isinstance(split, dict) else []
        features = _select_features(args.disc_features, sorted(disc_list), args.disc_max_features)
        plot_fn = discrete_grid_matplotlib if args.figure == "disc_grid" else discrete_points_matplotlib
        plot_fn(args.generated, args.reference, features, out_path, max_rows=max(1000, args.lines_max_rows))
        print("saved", out_path)
        return
    # Default: the paper-style multi-panel figure.
    ks_rows = read_csv_rows(args.ks_per_feature)
    shift_rows = read_csv_rows(args.data_shift)
    mh_rows = read_csv_rows(args.metrics_history)
    fm = read_json(args.filtered_metrics)
    cont_stats = read_json(args.cont_stats)
    profile = build_mean_profile(
        ks_rows, cont_stats, args.generated, order=args.profile_order, max_features=args.profile_max_features
    )
    bh_rows = [
        {"seed": r["seed"], "avg_ks": r["avg_ks"], "avg_jsd": r["avg_jsd"], "avg_lag1_diff": r["avg_lag1_diff"]}
        for r in rows
    ]
    if args.engine in {"auto", "matplotlib"}:
        try:
            panel_matplotlib(bh_rows, ks_rows, shift_rows, mh_rows, fm, profile, out_path)
            print("saved", out_path)
            return
        except Exception:
            if args.engine == "matplotlib":
                raise
    panel_svg(bh_rows, ks_rows, shift_rows, mh_rows, fm, profile, out_path)
    print("saved", out_path)
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()