Files
mask-ddpm/example/plot_benchmark.py
Mingzhe Yang d612d8e785 update
2026-02-06 01:54:43 +08:00

957 lines
35 KiB
Python

#!/usr/bin/env python3
import argparse
import csv
import json
import math
from pathlib import Path
def parse_args():
    """Define and parse this script's command-line options from ``sys.argv``.

    Every input-file option defaults to a file inside ``results/`` next to
    this script. ``--out`` defaults to the empty string, which tells the
    caller to pick an output path based on ``--figure``.
    """
    results_dir = Path(__file__).resolve().parent / "results"
    cli = argparse.ArgumentParser(description="Plot benchmark metrics from benchmark_history.csv")
    cli.add_argument(
        "--figure",
        choices=["panel", "summary"],
        default="panel",
        help="Figure type: panel (paper-style multi-panel) or summary (seed robustness only).",
    )
    # (flag, default file name under results/, help text) -- registered in this order.
    input_files = (
        ("--history", "benchmark_history.csv", "Path to benchmark_history.csv"),
        ("--ks-per-feature", "ks_per_feature.csv", "Path to ks_per_feature.csv"),
        ("--data-shift", "data_shift_stats.csv", "Path to data_shift_stats.csv"),
        ("--metrics-history", "metrics_history.csv", "Path to metrics_history.csv"),
        ("--filtered-metrics", "filtered_metrics.json", "Path to filtered_metrics.json (optional)."),
    )
    for flag, file_name, help_text in input_files:
        cli.add_argument(flag, default=str(results_dir / file_name), help=help_text)
    cli.add_argument("--out", default="", help="Output SVG path (default depends on --figure).")
    cli.add_argument(
        "--engine",
        choices=["auto", "matplotlib", "svg"],
        default="auto",
        help="Plotting engine: auto prefers matplotlib if available; otherwise uses pure-SVG.",
    )
    return cli.parse_args()
def mean_std(vals):
    """Return ``(mean, sample standard deviation)`` of the non-empty sequence *vals*.

    The deviation uses the Bessel-corrected ``n - 1`` denominator; a single
    value is reported with a deviation of ``0.0``.
    """
    count = len(vals)
    mean = sum(vals) / count
    if count == 1:
        return mean, 0.0
    squared_dev = sum((x - mean) * (x - mean) for x in vals)
    return mean, math.sqrt(squared_dev / (count - 1))
def svg_escape(s):
    """Escape ``str(s)`` so it is safe as SVG/XML text content or a quoted attribute value.

    ``&`` is replaced first so the entities produced for the remaining
    characters are not escaped a second time.
    """
    return (
        str(s)
        .replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
        .replace("'", "&apos;")
    )
def clamp(v, lo, hi):
    """Limit *v* to the closed interval ``[lo, hi]`` (lower bound checked first)."""
    return lo if v < lo else (hi if v > hi else v)
def lerp(a, b, t):
    """Linear interpolation: *a* at ``t == 0``, *b* at ``t == 1``."""
    span = b - a
    return a + span * t
def hex_to_rgb(h):
    """Convert ``"#rrggbb"`` (leading ``#`` optional) to an ``(r, g, b)`` tuple of ints."""
    digits = h.lstrip("#")
    return tuple(int(digits[pos:pos + 2], 16) for pos in (0, 2, 4))
def rgb_to_hex(r, g, b):
    """Format channels as ``"#rrggbb"``; each is clamped to [0, 255] and truncated to int."""
    red, green, blue = (int(clamp(channel, 0, 255)) for channel in (r, g, b))
    return "#%02x%02x%02x" % (red, green, blue)
def diverging_color(v, vmin=-2.0, vmax=2.0, cold="#2563eb", hot="#ef4444", mid="#ffffff"):
    """Map *v* onto a two-sided colour ramp and return a ``"#rrggbb"`` string.

    *v* is clamped to ``[vmin, vmax]``. Non-negative values blend from *mid*
    toward *hot* by the fraction ``v / vmax``; negative values blend from
    *mid* toward *cold* by ``(-v) / (-vmin)``. A zero bound on the relevant
    side yields *mid* unchanged.
    """
    v = clamp(v, vmin, vmax)
    if v >= 0:
        endpoint = hot
        t = 0.0 if vmax == 0 else v / vmax
    else:
        endpoint = cold
        t = 0.0 if vmin == 0 else (-v) / (-vmin)
    start_rgb = hex_to_rgb(mid)
    end_rgb = hex_to_rgb(endpoint)
    r, g, b = (lerp(s, e, t) for s, e in zip(start_rgb, end_rgb))
    return rgb_to_hex(r, g, b)
def plot_matplotlib(rows, seeds, metrics, out_path):
    """Draw the "summary" seed-robustness figure with matplotlib and save it as SVG.

    One subplot per metric: each run is a dot annotated with its seed, a red
    vertical line marks the mean and a shaded band spans mean +/- 1 sample std.
    The title reports how many seeds were plotted.

    Args:
        rows: dicts holding a numeric value under every metric key.
        seeds: seed labels, parallel to *rows*.
        metrics: non-empty list of ``(key, title)`` pairs, one subplot each.
        out_path: ``pathlib.Path`` of the SVG to write; parent dirs are created.

    Raises:
        ImportError: if matplotlib is not installed.
        ValueError: if *seeds* and *rows* differ in length (``zip(strict=True)``).
    """
    import matplotlib.pyplot as plt

    try:
        plt.style.use("seaborn-v0_8-whitegrid")
    except Exception:
        # Style names differ across matplotlib releases; the default style is an acceptable fallback.
        pass
    fig, axes = plt.subplots(nrows=len(metrics), ncols=1, figsize=(8.6, 4.6), sharex=False)
    try:
        if len(metrics) == 1:
            # subplots() returns a bare Axes rather than an array when there is one row.
            axes = [axes]
        point_color = "#3b82f6"
        band_color = "#ef4444"
        grid_color = "#e5e7eb"
        axis_color = "#111827"
        for ax, (key, title) in zip(axes, metrics, strict=True):
            vals = [r[key] for r in rows]
            m, s = mean_std(vals)
            # x-range covers all points and the +/-1 std band, padded by 20% on each side.
            vmin = min(vals + [m - s])
            vmax = max(vals + [m + s])
            if vmax == vmin:
                vmax = vmin + 1.0
            vr = vmax - vmin
            vmin -= 0.20 * vr
            vmax += 0.20 * vr
            y0 = 0.0
            # Cycle small vertical offsets so runs with equal values stay distinguishable.
            jitter = [-0.08, 0.0, 0.08]
            ys = [y0 + jitter[i % len(jitter)] for i, _ in enumerate(vals)]
            ax.axvspan(m - s, m + s, color=band_color, alpha=0.10, linewidth=0)
            ax.axvline(m, color=band_color, linewidth=2.2)
            ax.scatter(vals, ys, s=46, color=point_color, edgecolors="white", linewidths=1.0, zorder=3)
            for x, y, seed in zip(vals, ys, seeds, strict=True):
                ax.annotate(
                    str(seed),
                    (x, y),
                    textcoords="offset points",
                    xytext=(0, 10),
                    ha="center",
                    va="bottom",
                    fontsize=8,
                    color=axis_color,
                )
            ax.set_title(title, loc="left", fontsize=11, color=axis_color, pad=8)
            ax.set_yticks([])
            ax.set_ylim(-0.35, 0.35)
            ax.set_xlim(vmin, vmax)
            ax.grid(True, axis="x", color=grid_color)
            ax.grid(False, axis="y")
            ax.text(
                0.99,
                0.80,
                "mean={m:.4f} ± {s:.4f}".format(m=m, s=s),
                transform=ax.transAxes,
                ha="right",
                va="center",
                fontsize=9,
                color="#374151",
            )
        # Report the actual number of plotted seeds instead of a hard-coded count.
        n_seeds = len(seeds)
        seed_word = "seed" if n_seeds == 1 else "seeds"
        fig.suptitle(
            "Benchmark Metrics ({n} {w}) · lower is better".format(n=n_seeds, w=seed_word),
            fontsize=12,
            color=axis_color,
            y=0.98,
        )
        fig.tight_layout(rect=(0, 0, 1, 0.95))
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, format="svg")
    finally:
        # Release the figure even when drawing/saving fails (main may then fall back to SVG).
        plt.close(fig)
def plot_svg(rows, seeds, metrics, out_path):
    """Draw the "summary" seed-robustness figure as hand-written SVG (no matplotlib needed).

    One horizontal strip per metric: runs are dots, a red line marks the mean
    and a translucent band spans mean +/- 1 sample std. The title reports how
    many seeds were plotted and the subtitle lists them.

    Args:
        rows: dicts holding a numeric value under every metric key.
        seeds: seed labels as strings (joined into the subtitle).
        metrics: non-empty list of ``(key, title)`` pairs, one strip each.
        out_path: ``pathlib.Path`` of the SVG to write; parent dirs are created.
    """
    W, H = 980, 440
    pad_l, pad_r, pad_t, pad_b = 200, 30, 74, 36
    row_gap = 52
    # Height left after padding and inter-row gaps, shared equally by the metric strips.
    row_h = (H - pad_t - pad_b - row_gap * (len(metrics) - 1)) / len(metrics)
    bg = "#ffffff"
    axis = "#2b2b2b"
    grid = "#e9e9e9"
    band = "#d62728"
    band_fill = "#d62728"
    point = "#1f77b4"
    text = "#111111"
    subtle = "#666666"
    parts = []
    parts.append(
        "<svg xmlns='http://www.w3.org/2000/svg' width='{w}' height='{h}' viewBox='0 0 {w} {h}'>".format(
            w=W, h=H
        )
    )
    parts.append("<rect x='0' y='0' width='{w}' height='{h}' fill='{bg}'/>".format(w=W, h=H, bg=bg))
    # Report the actual number of plotted seeds instead of a hard-coded count.
    n_seeds = len(seeds)
    title_text = "Benchmark Metrics ({n} {w})".format(n=n_seeds, w="seed" if n_seeds == 1 else "seeds")
    parts.append(
        "<text x='{x:.1f}' y='28' text-anchor='middle' font-family='Arial' font-size='16' fill='{c}'>{t}</text>".format(
            x=W / 2, c=text, t=svg_escape(title_text)
        )
    )
    parts.append(
        "<text x='{x:.1f}' y='48' text-anchor='middle' font-family='Arial' font-size='11' fill='{c}'>line: mean · band: ±1 std · dots: runs · lower is better</text>".format(
            x=W / 2, c=subtle
        )
    )
    parts.append(
        "<text x='{x:.1f}' y='64' text-anchor='middle' font-family='Arial' font-size='10' fill='{c}'>seeds: {s}</text>".format(
            x=W / 2, c=subtle, s=svg_escape(", ".join(seeds))
        )
    )
    plot_x0 = pad_l
    plot_x1 = W - pad_r
    for mi, (key, title) in enumerate(metrics):
        y0 = pad_t + mi * (row_h + row_gap)
        y1 = y0 + row_h
        yc = (y0 + y1) / 2
        vals = [r[key] for r in rows]
        m, s = mean_std(vals)
        # x-range covers all points and the +/-1 std band, padded by 20% on each side.
        vmin = min(vals + [m - s])
        vmax = max(vals + [m + s])
        if vmax == vmin:
            vmax = vmin + 1.0
        vr = vmax - vmin
        vmin -= 0.20 * vr
        vmax += 0.20 * vr

        def X(v):
            # Data value -> SVG x coordinate for the current metric's range.
            return plot_x0 + (v - vmin) * (plot_x1 - plot_x0) / (vmax - vmin)

        if mi % 2 == 1:
            # Zebra striping on every other strip.
            parts.append(
                "<rect x='{x}' y='{y}' width='{w}' height='{h}' fill='#fafafa'/>".format(
                    x=0, y=y0 - 8, w=W, h=row_h + 16
                )
            )
        parts.append(
            "<text x='{x}' y='{y:.1f}' text-anchor='end' font-family='Arial' font-size='12' fill='{c}'>{t}</text>".format(
                x=pad_l - 12, y=yc + 4, c=text, t=svg_escape(title)
            )
        )
        # Six evenly spaced vertical grid lines with tick labels.
        for k in range(6):
            xx = plot_x0 + k * (plot_x1 - plot_x0) / 5
            parts.append(
                "<line x1='{x:.1f}' y1='{y0:.1f}' x2='{x:.1f}' y2='{y1:.1f}' stroke='{c}' stroke-width='1'/>".format(
                    x=xx, y0=y0 + 4, y1=y1 - 4, c=grid
                )
            )
            val = vmin + k * (vmax - vmin) / 5
            parts.append(
                "<text x='{x:.1f}' y='{y:.1f}' text-anchor='middle' font-family='Arial' font-size='10' fill='{c}'>{v:.4f}</text>".format(
                    x=xx, y=y1 + 28, c=subtle, v=val
                )
            )
        parts.append(
            "<line x1='{x0}' y1='{y:.1f}' x2='{x1}' y2='{y:.1f}' stroke='{c}' stroke-width='1.2'/>".format(
                x0=plot_x0, x1=plot_x1, y=yc, c=axis
            )
        )
        x_lo = X(m - s)
        x_hi = X(m + s)
        parts.append(
            "<rect x='{x:.1f}' y='{y:.1f}' width='{w:.1f}' height='{h:.1f}' fill='{c}' fill-opacity='0.12' stroke='none'/>".format(
                x=min(x_lo, x_hi), y=yc - 14, w=abs(x_hi - x_lo), h=28, c=band_fill
            )
        )
        xm = X(m)
        parts.append(
            "<line x1='{x:.1f}' y1='{y0:.1f}' x2='{x:.1f}' y2='{y1:.1f}' stroke='{c}' stroke-width='2'/>".format(
                x=xm, y0=yc - 16, y1=yc + 16, c=band
            )
        )
        parts.append(
            "<text x='{x:.1f}' y='{y:.1f}' text-anchor='start' font-family='Arial' font-size='10' fill='{c}'>mean={m:.4f}±{s:.4f}</text>".format(
                x=xm + 6, y=yc - 18, c=subtle, m=m, s=s
            )
        )
        for i, r in enumerate(rows):
            # Deterministic pseudo-jitter (-4.5..+4.5 px) so equal values do not overlap exactly.
            jitter = ((i * 37) % 11 - 5) * 0.9
            xx = X(r[key])
            yy = yc + jitter
            parts.append("<circle cx='{x:.1f}' cy='{y:.1f}' r='5' fill='{c}'/>".format(x=xx, y=yy, c=point))
    parts.append("</svg>")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("\n".join(parts), encoding="utf-8")
def read_csv_rows(path):
    """Load a header-row CSV as a list of dicts (string values); ``[]`` when *path* does not exist."""
    csv_path = Path(path)
    if not csv_path.exists():
        return []
    with csv_path.open("r", encoding="utf-8", newline="") as handle:
        rows = list(csv.DictReader(handle))
    return rows
def read_json(path):
    """Parse the UTF-8 JSON file at *path*; return ``None`` when the file does not exist."""
    json_path = Path(path)
    if json_path.exists():
        return json.loads(json_path.read_text(encoding="utf-8"))
    return None
def parse_float(s):
    """Convert a CSV cell to ``float``, treating missing markers as ``None``.

    ``None``, blank/whitespace-only text, and ``"none"`` / ``"nan"`` in any
    letter case all yield ``None``. Any other non-numeric text raises
    ``ValueError`` from ``float()``.
    """
    if s is None:
        return None
    cell = str(s).strip()
    if cell.lower() in ("", "none", "nan"):
        return None
    return float(cell)
def zscores(vals):
    """Standardise *vals* using the population (divide-by-n) standard deviation.

    Returns ``[]`` for empty input and a list of ``0.0`` when all values are equal.
    """
    if not vals:
        return []
    n = len(vals)
    mu = sum(vals) / n
    sigma = math.sqrt(sum((x - mu) * (x - mu) for x in vals) / n)
    if sigma == 0:
        return [0.0] * n
    return [(x - mu) / sigma for x in vals]
def panel_svg(bh_rows, ks_rows, shift_rows, hist_rows, filtered_metrics, out_path):
    """Render the 2x2 "panel" overview figure as hand-built SVG and write it to *out_path*.

    Panels:
        A -- static workflow diagram (boxes and arrows).
        B -- horizontal bars for the (up to) 14 features with the highest KS.
        C -- heatmap of z-scored per-file feature means for a fixed feature list.
        D -- metric-history sparklines (only when *hist_rows* has complete rows)
             above a per-seed mean +/- 1 std strip plot for each metric.

    Args:
        bh_rows: per-seed dicts with "seed", "avg_ks", "avg_jsd", "avg_lag1_diff"
            (numeric values; ``main`` builds these from the history CSV).
        ks_rows: CSV row dicts; reads "feature", "ks", "gen_frac_at_min", "gen_frac_at_max".
        shift_rows: CSV row dicts; reads "file", "sample_rows" and ``mean_<feature>`` columns.
        hist_rows: CSV row dicts; reads "avg_ks", "avg_jsd", "avg_lag1_diff".
        filtered_metrics: parsed JSON or ``None``; when a dict, the "feature" names in its
            "dropped_features" list are printed in panel B.
        out_path: ``pathlib.Path`` of the SVG file; parent directories are created.
    """
    W, H = 1400, 900
    margin = 42
    gap = 26
    # Two equal columns and two equal rows, separated by `gap`.
    panel_w = (W - margin * 2 - gap) / 2
    panel_h = (H - margin * 2 - gap) / 2
    bg = "#ffffff"
    ink = "#111827"
    subtle = "#6b7280"
    border = "#e5e7eb"
    grid = "#eef2f7"
    blue = "#3b82f6"
    red = "#ef4444"
    green = "#10b981"

    # ---- Small SVG element builders (closures over the palette and panel size). ----
    def panel_rect(x, y):
        # Rounded card outlining one panel.
        return "<rect x='{x:.1f}' y='{y:.1f}' width='{w:.1f}' height='{h:.1f}' rx='16' ry='16' fill='#ffffff' stroke='{b}' stroke-width='1.2'/>".format(
            x=x, y=y, w=panel_w, h=panel_h, b=border
        )

    def text(x, y, s, size=12, anchor="start", color=ink, weight="normal"):
        # `s` is escaped here, so callers pass raw label text.
        return "<text x='{x:.1f}' y='{y:.1f}' text-anchor='{a}' font-family='Arial' font-size='{fs}' font-weight='{w}' fill='{c}'>{t}</text>".format(
            x=x, y=y, a=anchor, fs=size, c=color, w=weight, t=svg_escape(s)
        )

    def line(x1, y1, x2, y2, color=border, width=1.0, dash=None, opacity=1.0, cap="round"):
        # Optional attributes are emitted only when they differ from SVG defaults.
        extra = ""
        if dash:
            extra += " stroke-dasharray='{d}'".format(d=dash)
        if opacity != 1.0:
            extra += " stroke-opacity='{o}'".format(o=opacity)
        return "<line x1='{x1:.1f}' y1='{y1:.1f}' x2='{x2:.1f}' y2='{y2:.1f}' stroke='{c}' stroke-width='{w}' stroke-linecap='{cap}'{extra}/>".format(
            x1=x1, y1=y1, x2=x2, y2=y2, c=color, w=width, cap=cap, extra=extra
        )

    def round_box(x, y, w, h, fill="#ffffff", stroke=border, sw=1.2, rx=12):
        return "<rect x='{x:.1f}' y='{y:.1f}' width='{w:.1f}' height='{h:.1f}' rx='{rx}' ry='{rx}' fill='{f}' stroke='{s}' stroke-width='{sw}'/>".format(
            x=x, y=y, w=w, h=h, rx=rx, f=fill, s=stroke, sw=sw
        )

    def arrow(x1, y1, x2, y2, color=ink, width=1.8):
        # Straight shaft plus a filled triangular head whose tip is at (x2, y2).
        ang = math.atan2(y2 - y1, x2 - x1)
        ah = 10.0  # head length, measured back along the shaft
        aw = 5.0  # half the width of the head's base
        hx = x2 - ah * math.cos(ang)
        hy = y2 - ah * math.sin(ang)
        # (px, py) is perpendicular to the shaft with length aw.
        px = aw * math.sin(ang)
        py = -aw * math.cos(ang)
        p1x, p1y = hx + px, hy + py
        p2x, p2y = hx - px, hy - py
        return (
            "<path d='M {x1:.1f} {y1:.1f} L {x2:.1f} {y2:.1f}' fill='none' stroke='{c}' stroke-width='{w}' stroke-linecap='round'/>"
            "<path d='M {x2:.1f} {y2:.1f} L {p1x:.1f} {p1y:.1f} L {p2x:.1f} {p2y:.1f} Z' fill='{c}'/>"
        ).format(x1=x1, y1=y1, x2=x2, y2=y2, c=color, w=width, p1x=p1x, p1y=p1y, p2x=p2x, p2y=p2y)

    # ---- Canvas, title and the four panel cards. ----
    parts = []
    parts.append(
        "<svg xmlns='http://www.w3.org/2000/svg' width='{w}' height='{h}' viewBox='0 0 {w} {h}'>".format(w=W, h=H)
    )
    parts.append("<rect x='0' y='0' width='{w}' height='{h}' fill='{bg}'/>".format(w=W, h=H, bg=bg))
    parts.append(text(W / 2, 32, "Benchmark Overview (HAI Security Dataset)", size=18, anchor="middle", weight="bold"))
    parts.append(
        text(
            W / 2,
            54,
            "A: workflow · B: per-feature KS · C: train-file mean shift · D: seed robustness and metric history",
            size=11,
            anchor="middle",
            color=subtle,
        )
    )
    # Panel origins; the extra 36 px leaves room for the figure title.
    xA, yA = margin, margin + 36
    xB, yB = margin + panel_w + gap, margin + 36
    xC, yC = margin, margin + 36 + panel_h + gap
    xD, yD = margin + panel_w + gap, margin + 36 + panel_h + gap
    parts.append(panel_rect(xA, yA))
    parts.append(panel_rect(xB, yB))
    parts.append(panel_rect(xC, yC))
    parts.append(panel_rect(xD, yD))

    def panel_label(x, y, letter, title):
        # Appends directly to `parts` (bold letter followed by the panel title).
        parts.append(text(x + 18, y + 28, letter, size=16, weight="bold"))
        parts.append(text(x + 44, y + 28, title, size=14, weight="bold"))

    panel_label(xA, yA, "A", "Typed Hybrid Generation")
    panel_label(xB, yB, "B", "Feature-Level Distribution Fidelity")
    panel_label(xC, yC, "C", "Dataset Shift Across Training Files")
    panel_label(xD, yD, "D", "Robustness Across Seeds")

    # ---- Panel A: static workflow diagram. ----
    ax0 = xA + 22
    ay0 = yA + 56
    aw0 = panel_w - 44
    ah0 = panel_h - 78
    box_h = 56
    # Four boxes share the width with three 26 px gaps that hold the arrows.
    box_w = (aw0 - 3 * 26) / 4
    by = ay0 + (ah0 - box_h) / 2 - 14
    boxes = [
        ("HAI windows\n(L=96)", "#f8fafc"),
        ("Typed\ndecomposition", "#f8fafc"),
        ("Hybrid\ngenerator", "#f8fafc"),
        ("Synthetic\nwindows", "#f8fafc"),
    ]
    for i, (lbl, fill) in enumerate(boxes):
        bx = ax0 + i * (box_w + 26)
        parts.append(round_box(bx, by, box_w, box_h, fill=fill))
        # Multi-line label: first line bold, the rest normal weight.
        for j, line_txt in enumerate(lbl.split("\n")):
            parts.append(text(bx + box_w / 2, by + 22 + j * 16, line_txt, size=11, anchor="middle", weight="bold" if j == 0 else "normal"))
        if i < len(boxes) - 1:
            parts.append(arrow(bx + box_w, by + box_h / 2, bx + box_w + 26, by + box_h / 2, color=subtle, width=1.6))
    # Route boxes hang below the third ("Hybrid generator") box.
    hx = ax0 + 2 * (box_w + 26)
    hy = by + box_h + 18
    parts.append(text(hx + 6, hy - 6, "Type-aware routes", size=10, color=subtle, weight="bold"))
    inner_w = box_w
    inner_gap = 10
    inner_h = 32
    inner_y = hy
    inner_colors = [("#e0f2fe", blue, "Trend (det.)"), ("#fee2e2", red, "Residual (DDPM)"), ("#dcfce7", green, "Discrete head")]
    for k, (fill, stroke, name) in enumerate(inner_colors):
        iy = inner_y + k * (inner_h + inner_gap)
        parts.append(round_box(hx, iy, inner_w, inner_h, fill=fill, stroke=stroke, sw=1.4, rx=10))
        parts.append(text(hx + 10, iy + 20, name, size=10, color=ink, weight="bold"))
    parts.append(arrow(hx + inner_w / 2, by + box_h, hx + inner_w / 2, inner_y, color=subtle, width=1.4))
    parts.append(
        text(
            xA + 22,
            yA + panel_h - 18,
            "Separation aligns metrics with data types: KS (continuous), JSD (discrete), lag-1 (temporal).",
            size=10,
            color=subtle,
        )
    )

    # ---- Panel B: highest-KS features as horizontal bars on a fixed 0..1 axis. ----
    bx0 = xB + 22
    by0 = yB + 62
    bw0 = panel_w - 44
    bh0 = panel_h - 86
    # Rows without a parseable "ks" are skipped; the rest are sorted worst (highest KS) first.
    ks_sorted = sorted(
        [
            {
                "feature": r.get("feature", ""),
                "ks": parse_float(r.get("ks")),
                "gen_frac_at_min": parse_float(r.get("gen_frac_at_min")),
                "gen_frac_at_max": parse_float(r.get("gen_frac_at_max")),
            }
            for r in ks_rows
            if parse_float(r.get("ks")) is not None
        ],
        key=lambda x: x["ks"],
        reverse=True,
    )
    top_n = 14 if len(ks_sorted) >= 14 else len(ks_sorted)
    top = ks_sorted[:top_n]
    dropped = []
    if isinstance(filtered_metrics, dict):
        for d in filtered_metrics.get("dropped_features", []) or []:
            feat = d.get("feature")
            if feat:
                dropped.append(feat)
    parts.append(text(bx0, by0 - 8, "Top-{n} KS outliers (lower is better)".format(n=top_n), size=11, color=subtle))
    if dropped:
        parts.append(text(bx0 + bw0, by0 - 8, "dropped: {d}".format(d=", ".join(dropped)), size=10, anchor="end", color=subtle))
    chart_y0 = by0 + 16
    chart_h = bh0 - 32
    label_w = 180
    x0 = bx0 + label_w
    x1 = bx0 + bw0 - 16
    # max(1, ...) avoids division by zero when there are no KS rows.
    row_h = chart_h / max(1, top_n)
    for t in range(6):
        xx = x0 + (x1 - x0) * (t / 5)
        parts.append(line(xx, chart_y0, xx, chart_y0 + chart_h, color=grid, width=1.0))
        parts.append(text(xx, chart_y0 + chart_h + 18, "{:.1f}".format(t / 5), size=9, anchor="middle", color=subtle))
    parts.append(line(x0, chart_y0 + chart_h, x1, chart_y0 + chart_h, color=border, width=1.2, cap="butt"))
    for i, r in enumerate(top):
        fy = chart_y0 + i * row_h + row_h / 2
        feat = r["feature"]
        ks = r["ks"]
        gmin = r["gen_frac_at_min"] or 0.0
        gmax = r["gen_frac_at_max"] or 0.0
        # NOTE(review): gen_frac_at_min/max are presumably the fraction of generated values
        # sitting at the feature's min/max; >= 0.98 is flagged as "collapse". Verify
        # against whatever writes ks_per_feature.csv.
        collapsed = (gmin >= 0.98) or (gmax >= 0.98)
        bar_color = "#0ea5e9" if not collapsed else "#fb7185"
        parts.append(text(x0 - 10, fy + 4, feat, size=9, anchor="end", color=ink))
        parts.append(text(x1 + 8, fy + 4, "{:.3f}".format(ks), size=9, anchor="start", color=subtle))
        w = (x1 - x0) * clamp(ks, 0.0, 1.0)
        parts.append(
            "<rect x='{x:.1f}' y='{y:.1f}' width='{w:.1f}' height='{h:.1f}' rx='6' ry='6' fill='{c}' fill-opacity='0.85'/>".format(
                x=x0, y=fy - row_h * 0.34, w=w, h=row_h * 0.68, c=bar_color
            )
        )
        if collapsed:
            # Right-aligned inside the bar; capped so it stays within the chart.
            parts.append(
                text(
                    x0 + min(w, (x1 - x0) - 10),
                    fy + 4,
                    "collapse",
                    size=8,
                    anchor="end",
                    color="#7f1d1d",
                    weight="bold",
                )
            )

    # ---- Panel C: z-scored per-file means for a fixed feature list. ----
    cx0 = xC + 22
    cy0 = yC + 64
    cw0 = panel_w - 44
    ch0 = panel_h - 88
    if shift_rows:
        cols = list(shift_rows[0].keys())
    else:
        cols = []
    mean_cols = [c for c in cols if c.startswith("mean_")]
    # Only these columns are shown, and only if present in the CSV header.
    wanted = ["mean_P1_FT01", "mean_P1_LIT01", "mean_P1_PIT01", "mean_P2_CO_rpm", "mean_P3_LIT01", "mean_P4_ST_PT01"]
    feats = [c for c in wanted if c in mean_cols]
    files = [r.get("file", "") for r in shift_rows]
    sample_rows = [parse_float(r.get("sample_rows")) for r in shift_rows]
    # Missing/unparseable cells become 0.0 before z-scoring (this affects the column's z-scores).
    feat_vals = {c: [parse_float(r.get(c)) or 0.0 for r in shift_rows] for c in feats}
    # Each feature is standardised across files independently.
    feat_z = {c: zscores(vs) for c, vs in feat_vals.items()}
    parts.append(text(cx0, cy0 - 10, "Mean shift (z-score) across train files", size=11, color=subtle))
    if files:
        parts.append(text(cx0 + cw0, cy0 - 10, "rows: {r}".format(r=", ".join(str(int(x)) for x in sample_rows if x is not None)), size=10, anchor="end", color=subtle))
    heat_x0 = cx0 + 160
    heat_y0 = cy0 + 10
    heat_w = cw0 - 180
    heat_h = ch0 - 44
    # At least one row/column so the empty grid still renders.
    n_rows = max(1, len(files))
    n_cols = max(1, len(feats))
    cell_w = heat_w / n_cols
    cell_h = heat_h / n_rows
    for j, c in enumerate(feats):
        label = c.replace("mean_", "")
        parts.append(text(heat_x0 + j * cell_w + cell_w / 2, heat_y0 - 8, label, size=9, anchor="middle", color=ink, weight="bold"))
    for i, f in enumerate(files):
        fy = heat_y0 + i * cell_h + cell_h / 2
        parts.append(text(cx0 + 140, fy + 4, f, size=9, anchor="end", color=ink))
    for i in range(n_rows + 1):
        yy = heat_y0 + i * cell_h
        parts.append(line(heat_x0, yy, heat_x0 + heat_w, yy, color=border, width=1.0, cap="butt"))
    for j in range(n_cols + 1):
        xx = heat_x0 + j * cell_w
        parts.append(line(xx, heat_y0, xx, heat_y0 + heat_h, color=border, width=1.0, cap="butt"))
    for i in range(len(files)):
        for j, c in enumerate(feats):
            z = feat_z[c][i] if i < len(feat_z[c]) else 0.0
            # Colour saturates at |z| = 2.
            fill = diverging_color(z, vmin=-2.0, vmax=2.0)
            parts.append(
                "<rect x='{x:.1f}' y='{y:.1f}' width='{w:.1f}' height='{h:.1f}' fill='{c}'/>".format(
                    x=heat_x0 + j * cell_w + 0.6, y=heat_y0 + i * cell_h + 0.6, w=cell_w - 1.2, h=cell_h - 1.2, c=fill
                )
            )
            parts.append(text(heat_x0 + j * cell_w + cell_w / 2, heat_y0 + i * cell_h + cell_h / 2 + 4, "{:+.2f}".format(z), size=9, anchor="middle", color=ink))
    # Colour legend: 25 swatches spanning z = -2 .. +2.
    lx0 = heat_x0
    ly0 = heat_y0 + heat_h + 18
    parts.append(text(cx0 + 140, ly0 + 4, "z", size=9, anchor="end", color=subtle))
    grad_w = 240
    for k in range(25):
        t = k / 24
        z = lerp(-2.0, 2.0, t)
        fill = diverging_color(z)
        # +0.2 px overlap hides hairline seams between adjacent swatches.
        parts.append(
            "<rect x='{x:.1f}' y='{y:.1f}' width='{w:.1f}' height='10' fill='{c}'/>".format(
                x=lx0 + k * (grad_w / 25), y=ly0 - 6, w=(grad_w / 25) + 0.2, c=fill
            )
        )
    parts.append(text(lx0, ly0 + 16, "-2", size=9, color=subtle))
    parts.append(text(lx0 + grad_w / 2, ly0 + 16, "0", size=9, anchor="middle", color=subtle))
    parts.append(text(lx0 + grad_w, ly0 + 16, "+2", size=9, anchor="end", color=subtle))

    # ---- Panel D: metric-history sparklines, then per-seed strip plot. ----
    dx0 = xD + 22
    dy0 = yD + 62
    dw0 = panel_w - 44
    dh0 = panel_h - 86
    metrics = [
        ("avg_ks", "KS (cont.)"),
        ("avg_jsd", "JSD (disc.)"),
        ("avg_lag1_diff", "Abs Δ lag-1"),
    ]
    bh_rows = sorted(bh_rows, key=lambda r: r.get("seed", 0))
    seeds = [str(r.get("seed", "")) for r in bh_rows]
    parts.append(text(dx0, dy0 - 10, "Seed robustness (mean ± 1 std; dots: seeds)", size=11, color=subtle))
    if seeds:
        parts.append(text(dx0 + dw0, dy0 - 10, "seeds: {s}".format(s=", ".join(seeds)), size=10, anchor="end", color=subtle))
    # Space for three sparklines is reserved even when there is no history to draw.
    spark_h = 48
    spark_gap = 10
    forest_y0 = dy0 + spark_h * 3 + spark_gap * 2 + 18
    forest_h = dh0 - (spark_h * 3 + spark_gap * 2 + 28)
    # Keep only history rows where all three metrics parse.
    hist_clean = []
    for r in hist_rows:
        ks = parse_float(r.get("avg_ks"))
        jsd = parse_float(r.get("avg_jsd"))
        lag = parse_float(r.get("avg_lag1_diff"))
        if ks is None or jsd is None or lag is None:
            continue
        hist_clean.append({"avg_ks": ks, "avg_jsd": jsd, "avg_lag1_diff": lag})
    if hist_clean:
        for mi, (k, title) in enumerate(metrics):
            y0 = dy0 + mi * (spark_h + spark_gap)
            y1 = y0 + spark_h
            parts.append(round_box(dx0, y0, dw0, spark_h, fill="#f9fafb", stroke=border, sw=1.0, rx=10))
            parts.append(text(dx0 + 10, y0 + 18, title + " history", size=10, color=subtle, weight="bold"))
            vals = [r[k] for r in hist_clean]
            # Each sparkline is min-max scaled independently (row order is the x-axis).
            vmin = min(vals)
            vmax = max(vals)
            if vmax == vmin:
                vmax = vmin + 1.0
            px0 = dx0 + 130
            px1 = dx0 + dw0 - 12
            py0 = y0 + 30
            py1 = y1 - 10
            parts.append(line(px0, py1, px1, py1, color=border, width=1.0, cap="butt"))
            parts.append(line(px0, py0, px0, py1, color=border, width=1.0, cap="butt"))
            pts = []
            n = len(vals)
            for i, v in enumerate(vals):
                x = px0 + (px1 - px0) * (i / max(1, n - 1))
                y = py1 - (v - vmin) * (py1 - py0) / (vmax - vmin)
                pts.append((x, y))
            d = "M " + " L ".join("{:.1f} {:.1f}".format(x, y) for x, y in pts)
            parts.append("<path d='{d}' fill='none' stroke='{c}' stroke-width='2'/>".format(d=d, c=blue if mi == 0 else (red if mi == 1 else green)))
            # Latest value shown at the right edge of the sparkline box.
            parts.append(text(dx0 + dw0 - 12, y0 + 18, "{:.3f}".format(vals[-1]), size=10, anchor="end", color=ink, weight="bold"))
    fx0 = dx0
    fy0 = forest_y0
    fw0 = dw0
    fh0 = forest_h
    x0 = fx0 + 190
    x1 = fx0 + fw0 - 18
    for t in range(6):
        xx = x0 + (x1 - x0) * (t / 5)
        parts.append(line(xx, fy0, xx, fy0 + fh0, color=grid, width=1.0))
    for mi, (key, title) in enumerate(metrics):
        # Each metric gets one third of the strip-plot height.
        y0 = fy0 + mi * (fh0 / 3)
        y1 = fy0 + (mi + 1) * (fh0 / 3)
        yc = (y0 + y1) / 2
        vals = [r.get(key) for r in bh_rows if r.get(key) is not None]
        if not vals:
            continue
        m, s = mean_std(vals)
        # Range covers all points and the +/-1 std band, padded 15% on each side.
        vmin = min(vals + [m - s])
        vmax = max(vals + [m + s])
        if vmax == vmin:
            vmax = vmin + 1.0
        vr = vmax - vmin
        vmin -= 0.15 * vr
        vmax += 0.15 * vr

        def X(v):
            # Data value -> x coordinate for the current metric's range.
            return x0 + (v - vmin) * (x1 - x0) / (vmax - vmin)

        parts.append(text(x0 - 14, yc + 4, title, size=11, anchor="end", color=ink, weight="bold"))
        parts.append(line(x0, yc, x1, yc, color=border, width=1.2, cap="butt"))
        # Std band; width floored at 1 px so a zero-std band stays visible.
        parts.append(
            "<rect x='{x:.1f}' y='{y:.1f}' width='{w:.1f}' height='20' rx='10' ry='10' fill='{c}' fill-opacity='0.10'/>".format(
                x=X(m - s), y=yc - 10, w=max(1.0, X(m + s) - X(m - s)), c=red
            )
        )
        parts.append(line(X(m), yc - 14, X(m), yc + 14, color=red, width=2.4))
        parts.append(text(x1, yc - 16, "mean={m:.4f}±{s:.4f}".format(m=m, s=s), size=9, anchor="end", color=subtle))
        for i, v in enumerate(vals):
            # Deterministic pseudo-jitter (-4.8..+4.8 px) so equal values do not overlap exactly.
            jitter = ((i * 37) % 9 - 4) * 1.2
            parts.append("<circle cx='{x:.1f}' cy='{y:.1f}' r='5' fill='{c}'/>".format(x=X(v), y=yc + jitter, c=blue))
    parts.append("</svg>")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("\n".join(parts), encoding="utf-8")
def panel_matplotlib(bh_rows, ks_rows, shift_rows, hist_rows, filtered_metrics, out_path):
    """Render the 2x2 "panel" overview figure with matplotlib and save it as SVG.

    Panels: A -- workflow diagram; B -- the 14 highest-KS features as horizontal
    bars; C -- heatmap of z-scored per-file feature means; D -- per-seed
    mean +/- 1 std strip plot per metric.

    Args:
        bh_rows: per-seed dicts with "seed", "avg_ks", "avg_jsd", "avg_lag1_diff"
            (numeric values; ``main`` builds these).
        ks_rows: CSV row dicts; reads "feature", "ks", "gen_frac_at_min", "gen_frac_at_max".
        shift_rows: CSV row dicts; reads "file" and ``mean_<feature>`` columns.
        hist_rows: accepted for signature parity with ``panel_svg``.
            NOTE(review): it is never used here, so the metric-history sparklines
            that ``panel_svg`` draws are missing from this engine's output.
        filtered_metrics: parsed JSON or ``None``; when a dict with "dropped_features",
            their "feature" names are printed in panel B.
        out_path: ``pathlib.Path`` of the SVG file; parent directories are created.

    Raises:
        ImportError: if matplotlib is not installed (``main`` relies on this to fall back).
    """
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    try:
        plt.style.use("seaborn-v0_8-whitegrid")
    except Exception:
        # Style name differs between matplotlib releases; the default style is acceptable.
        pass
    fig = plt.figure(figsize=(13.6, 8.6))
    gs = fig.add_gridspec(2, 2, wspace=0.18, hspace=0.22)
    axA = fig.add_subplot(gs[0, 0])
    axB = fig.add_subplot(gs[0, 1])
    axC = fig.add_subplot(gs[1, 0])
    axD = fig.add_subplot(gs[1, 1])
    fig.suptitle("Benchmark Overview (HAI Security Dataset)", fontsize=16, y=0.98)

    # ---- Panel A: workflow diagram drawn in unit axes coordinates (0..1). ----
    axA.set_title("A Typed Hybrid Generation", loc="left", fontsize=12, fontweight="bold")
    axA.axis("off")
    axA.set_xlim(0, 1)
    axA.set_ylim(0, 1)
    box_y = 0.55
    box_w = 0.18
    box_h = 0.16
    x_positions = [0.06, 0.30, 0.54, 0.78]
    labels = ["HAI windows\n(L=96)", "Typed\ndecomposition", "Hybrid\ngenerator", "Synthetic\nwindows"]
    for x, lbl in zip(x_positions, labels, strict=True):
        axA.add_patch(patches.FancyBboxPatch((x, box_y), box_w, box_h, boxstyle="round,pad=0.02,rounding_size=0.02", facecolor="#f8fafc", edgecolor="#e5e7eb"))
        axA.text(x + box_w / 2, box_y + box_h / 2, lbl, ha="center", va="center", fontsize=10, fontweight="bold")
    # Arrows between consecutive boxes (an empty annotate text draws only the arrow).
    for i in range(3):
        x1 = x_positions[i] + box_w
        x2 = x_positions[i + 1]
        axA.annotate("", xy=(x2, box_y + box_h / 2), xytext=(x1, box_y + box_h / 2), arrowprops=dict(arrowstyle="-|>", lw=1.4, color="#6b7280"))
    # Route boxes are stacked under the third ("Hybrid generator") box.
    hx = x_positions[2]
    hy = 0.20
    axA.text(hx, hy + 0.27, "Type-aware routes", fontsize=9, color="#6b7280", fontweight="bold")
    inner = [("Trend (det.)", "#e0f2fe", "#3b82f6"), ("Residual (DDPM)", "#fee2e2", "#ef4444"), ("Discrete head", "#dcfce7", "#10b981")]
    for k, (name, fc, ec) in enumerate(inner):
        y = hy + 0.18 - k * 0.11
        axA.add_patch(patches.FancyBboxPatch((hx, y), box_w, 0.08, boxstyle="round,pad=0.02,rounding_size=0.02", facecolor=fc, edgecolor=ec, lw=1.2))
        axA.text(hx + 0.01, y + 0.04, name, ha="left", va="center", fontsize=9, fontweight="bold")
    axA.annotate("", xy=(hx + box_w / 2, hy + 0.20), xytext=(hx + box_w / 2, box_y), arrowprops=dict(arrowstyle="-|>", lw=1.2, color="#6b7280"))
    axA.text(0.06, 0.06, "Metrics align with types: KS (continuous), JSD (discrete), lag-1 (temporal).", fontsize=9, color="#6b7280")

    # ---- Panel B: highest-KS features (rows without a parseable "ks" are skipped). ----
    axB.set_title("B Feature-Level Distribution Fidelity", loc="left", fontsize=12, fontweight="bold")
    ks_sorted = sorted(
        [
            {
                "feature": r.get("feature", ""),
                "ks": parse_float(r.get("ks")),
                "gen_frac_at_min": parse_float(r.get("gen_frac_at_min")) or 0.0,
                "gen_frac_at_max": parse_float(r.get("gen_frac_at_max")) or 0.0,
            }
            for r in ks_rows
            if parse_float(r.get("ks")) is not None
        ],
        key=lambda x: x["ks"],
        reverse=True,
    )
    top = ks_sorted[:14]
    # Reversed so barh (which plots bottom-up) shows the worst feature at the top.
    feats = [r["feature"] for r in top][::-1]
    vals = [r["ks"] for r in top][::-1]
    # NOTE(review): presumably the fraction of generated values at the feature's min/max;
    # >= 0.98 is highlighted as a collapse. Verify against the producer of ks_per_feature.csv.
    collapsed = [((r["gen_frac_at_min"] >= 0.98) or (r["gen_frac_at_max"] >= 0.98)) for r in top][::-1]
    colors = ["#fb7185" if c else "#0ea5e9" for c in collapsed]
    axB.barh(feats, vals, color=colors)
    axB.set_xlabel("KS (lower is better)")
    axB.set_xlim(0, 1.0)
    if isinstance(filtered_metrics, dict) and filtered_metrics.get("dropped_features"):
        dropped = ", ".join(d.get("feature", "") for d in filtered_metrics["dropped_features"] if d.get("feature"))
        if dropped:
            axB.text(0.99, 0.02, "dropped: {d}".format(d=dropped), transform=axB.transAxes, ha="right", va="bottom", fontsize=9, color="#6b7280")

    # ---- Panel C: z-scored per-file means for a fixed feature list. ----
    axC.set_title("C Dataset Shift Across Training Files", loc="left", fontsize=12, fontweight="bold")
    if shift_rows:
        cols = list(shift_rows[0].keys())
    else:
        cols = []
    mean_cols = [c for c in cols if c.startswith("mean_")]
    # Only these columns are shown, and only if present in the CSV header.
    wanted = ["mean_P1_FT01", "mean_P1_LIT01", "mean_P1_PIT01", "mean_P2_CO_rpm", "mean_P3_LIT01", "mean_P4_ST_PT01"]
    feats = [c for c in wanted if c in mean_cols]
    files = [r.get("file", "") for r in shift_rows]
    # M holds one list per feature (across files); missing/unparseable cells become 0.0.
    M = []
    for c in feats:
        M.append([parse_float(r.get(c)) or 0.0 for r in shift_rows])
    if M and files and feats:
        # z-score each feature across files, then transpose to rows = files, cols = features.
        Z = list(zip(*[zscores(col) for col in M], strict=True))
        im = axC.imshow(Z, aspect="auto", cmap="coolwarm", vmin=-2, vmax=2)
        axC.set_yticks(range(len(files)))
        axC.set_yticklabels(files)
        axC.set_xticks(range(len(feats)))
        axC.set_xticklabels([f.replace("mean_", "") for f in feats], rotation=25, ha="right")
        axC.set_ylabel("Train file")
        axC.set_xlabel("Feature mean z-score")
        fig.colorbar(im, ax=axC, fraction=0.046, pad=0.04)
    else:
        axC.axis("off")
        axC.text(0.5, 0.5, "missing data_shift_stats.csv", ha="center", va="center", fontsize=11, color="#6b7280")

    # ---- Panel D: per-seed strip plot drawn manually in unit axes coordinates. ----
    axD.set_title("D Robustness Across Seeds", loc="left", fontsize=12, fontweight="bold")
    axD.axis("off")
    axD.set_xlim(0, 1)
    axD.set_ylim(0, 1)
    metrics = [("avg_ks", "KS (cont.)"), ("avg_jsd", "JSD (disc.)"), ("avg_lag1_diff", "Abs Δ lag-1")]
    bh_rows = sorted(bh_rows, key=lambda r: r.get("seed", 0))
    for mi, (k, title) in enumerate(metrics):
        vals = [r.get(k) for r in bh_rows if r.get(k) is not None]
        if not vals:
            continue
        m, s = mean_std(vals)
        y = 0.78 - mi * 0.22
        axD.text(0.04, y, title, fontsize=10, fontweight="bold", va="center")
        x0 = 0.42
        x1 = 0.96
        # Range covers all points and the +/-1 std band, padded 15% on each side.
        vmin = min(vals + [m - s])
        vmax = max(vals + [m + s])
        if vmax == vmin:
            vmax = vmin + 1.0
        vr = vmax - vmin
        vmin -= 0.15 * vr
        vmax += 0.15 * vr

        def X(v):
            # Data value -> axes x coordinate within [x0, x1].
            return x0 + (v - vmin) * (x1 - x0) / (vmax - vmin)

        # Std band; width floored so a zero-std band stays visible.
        axD.add_patch(patches.FancyBboxPatch((X(m - s), y - 0.03), max(0.002, X(m + s) - X(m - s)), 0.06, boxstyle="round,pad=0.01,rounding_size=0.02", facecolor="#ef4444", alpha=0.12, edgecolor="none"))
        axD.plot([X(m), X(m)], [y - 0.05, y + 0.05], color="#ef4444", lw=2.2)
        # Cycle small vertical offsets so equal values stay distinguishable.
        jit = [-0.03, 0.0, 0.03]
        for i, v in enumerate(vals):
            axD.scatter([X(v)], [y + jit[i % len(jit)]], s=40, color="#3b82f6", edgecolor="white", linewidth=0.9, zorder=3)
        axD.text(0.96, y + 0.07, "mean={m:.4f}±{s:.4f}".format(m=m, s=s), fontsize=9, color="#6b7280", ha="right")
    fig.tight_layout(rect=(0, 0, 1, 0.96))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, format="svg")
    plt.close(fig)
def main():
    """Command-line entry point: load the benchmark history and write the requested SVG figure.

    The ``--history`` CSV must provide run_name, seed, avg_ks, avg_jsd and
    avg_lag1_diff columns. ``--figure summary`` draws the seed-robustness
    figure; ``--figure panel`` (default) also reads the KS-per-feature,
    data-shift, metrics-history and filtered-metrics inputs and draws the
    2x2 overview. Without ``--out`` the file goes to
    ``<script dir>/../figures/benchmark_{panel,metrics}.svg``.

    Engine selection: ``svg`` uses the pure-SVG renderer; ``matplotlib``
    propagates any matplotlib failure; ``auto`` tries matplotlib and, on any
    failure, reports the reason on stderr and falls back to pure SVG.

    Raises:
        SystemExit: if the history file is missing, empty, or has a malformed row.
    """
    import sys

    args = parse_args()
    hist_path = Path(args.history)
    if not hist_path.exists():
        raise SystemExit("missing history file: %s" % hist_path)
    rows = []
    with hist_path.open("r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        try:
            for r in reader:
                rows.append(
                    {
                        "run_name": r["run_name"],
                        "seed": int(r["seed"]),
                        "avg_ks": float(r["avg_ks"]),
                        "avg_jsd": float(r["avg_jsd"]),
                        "avg_lag1_diff": float(r["avg_lag1_diff"]),
                    }
                )
        except (KeyError, TypeError, ValueError) as exc:
            # Missing column (KeyError), short row -> None cell (TypeError) or non-numeric
            # text (ValueError): report the file and line instead of a bare traceback.
            raise SystemExit(
                "malformed history file %s (line %d): %r" % (hist_path, reader.line_num, exc)
            ) from exc
    if not rows:
        raise SystemExit("no rows in history file: %s" % hist_path)
    rows = sorted(rows, key=lambda x: x["seed"])
    seeds = [str(r["seed"]) for r in rows]
    metrics = [
        ("avg_ks", "KS (continuous)"),
        ("avg_jsd", "JSD (discrete)"),
        ("avg_lag1_diff", "Abs Δ lag-1 autocorr"),
    ]
    if args.out:
        out_path = Path(args.out)
    else:
        figures_dir = Path(__file__).resolve().parent.parent / "figures"
        file_name = "benchmark_panel.svg" if args.figure == "panel" else "benchmark_metrics.svg"
        out_path = figures_dir / file_name

    def render(matplotlib_fn, svg_fn, *fn_args):
        # Try matplotlib when the engine allows it; under "auto" any failure falls back to SVG.
        if args.engine in {"auto", "matplotlib"}:
            try:
                matplotlib_fn(*fn_args)
            except Exception as exc:
                if args.engine == "matplotlib":
                    raise
                # Best-effort fallback, but say why instead of hiding the failure.
                print(
                    "note: matplotlib engine failed ({e!r}); using pure-SVG engine".format(e=exc),
                    file=sys.stderr,
                )
            else:
                return
        svg_fn(*fn_args)

    if args.figure == "summary":
        render(plot_matplotlib, plot_svg, rows, seeds, metrics, out_path)
    else:
        ks_rows = read_csv_rows(args.ks_per_feature)
        shift_rows = read_csv_rows(args.data_shift)
        mh_rows = read_csv_rows(args.metrics_history)
        fm = read_json(args.filtered_metrics)
        bh_rows = [{"seed": r["seed"], "avg_ks": r["avg_ks"], "avg_jsd": r["avg_jsd"], "avg_lag1_diff": r["avg_lag1_diff"]} for r in rows]
        render(panel_matplotlib, panel_svg, bh_rows, ks_rows, shift_rows, mh_rows, fm, out_path)
    print("saved", out_path)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()