diff --git a/docs/范文.pdf b/docs/范文.pdf
new file mode 100644
index 0000000..7bafb1c
Binary files /dev/null and b/docs/范文.pdf differ
diff --git a/example/config.json b/example/config.json
index 255524b..691ae9d 100644
--- a/example/config.json
+++ b/example/config.json
@@ -51,9 +51,9 @@
"full_stats": true,
"type1_features": ["P1_B4002","P2_MSD","P4_HT_LD","P1_B2004","P1_B3004","P1_B4022","P1_B3005"],
"type2_features": ["P1_B4005"],
- "type3_features": ["P1_PCV02Z","P1_PCV01Z","P1_PCV01D","P1_FCV02Z"],
+ "type3_features": ["P1_PCV02Z","P1_PCV01Z","P1_PCV01D","P1_FCV02Z","P1_FCV03D","P1_FCV03Z","P1_LCV01D","P1_LCV01Z"],
"type4_features": ["P1_PIT02","P2_SIT02","P1_FT03"],
- "type5_features": ["P1_FT03Z"],
+ "type5_features": ["P1_FT03Z","P1_FT02Z"],
"type6_features": ["P4_HT_PO","P2_24Vdc","P2_HILout"],
"shuffle_buffer": 256,
"use_temporal_stage1": true,
diff --git a/example/plot_benchmark.py b/example/plot_benchmark.py
index a32dfb1..4baa3b1 100644
--- a/example/plot_benchmark.py
+++ b/example/plot_benchmark.py
@@ -12,15 +12,37 @@ def parse_args():
base_dir = Path(__file__).resolve().parent
parser.add_argument(
"--figure",
- choices=["panel", "summary"],
+ choices=["panel", "summary", "ranked_ks", "lines", "cdf_grid", "disc_grid", "disc_points"],
default="panel",
- help="Figure type: panel (paper-style multi-panel) or summary (seed robustness only).",
+        help="Figure type: panel (paper-style multi-panel), summary (seed robustness only), ranked_ks (outlier attribution), lines (feature series), cdf_grid (distributions), disc_grid (discrete marginals), or disc_points (discrete dot plot).",
)
parser.add_argument(
"--history",
default=str(base_dir / "results" / "benchmark_history.csv"),
help="Path to benchmark_history.csv",
)
+ parser.add_argument(
+ "--generated",
+ default=str(base_dir / "results" / "generated.csv"),
+ help="Path to generated.csv (used for per-feature profile in panel A).",
+ )
+ parser.add_argument(
+ "--cont-stats",
+ default=str(base_dir / "results" / "cont_stats.json"),
+ help="Path to cont_stats.json (used for per-feature profile in panel A).",
+ )
+ parser.add_argument(
+ "--profile-order",
+ choices=["ks_desc", "name"],
+ default="ks_desc",
+ help="Feature ordering for panel A profile: ks_desc or name.",
+ )
+ parser.add_argument(
+ "--profile-max-features",
+ type=int,
+ default=64,
+ help="Max features in panel A profile (0 means all).",
+ )
parser.add_argument(
"--ks-per-feature",
default=str(base_dir / "results" / "ks_per_feature.csv"),
@@ -41,6 +63,17 @@ def parse_args():
default=str(base_dir / "results" / "filtered_metrics.json"),
help="Path to filtered_metrics.json (optional).",
)
+ parser.add_argument(
+ "--ranked-ks",
+ default=str(base_dir / "results" / "ranked_ks.csv"),
+ help="Path to ranked_ks.csv (used for --figure ranked_ks).",
+ )
+ parser.add_argument(
+ "--ranked-ks-top-n",
+ type=int,
+ default=30,
+ help="Number of top KS features to show in ranked_ks figure.",
+ )
parser.add_argument(
"--out",
default="",
@@ -52,6 +85,73 @@ def parse_args():
default="auto",
help="Plotting engine: auto prefers matplotlib if available; otherwise uses pure-SVG.",
)
+ parser.add_argument(
+ "--lines-features",
+ default="",
+ help="Comma-separated feature names for --figure lines (default: top-4 from ranked_ks.csv or fallback set).",
+ )
+ parser.add_argument(
+ "--lines-top-k",
+ type=int,
+ default=8,
+ help="When features not specified, take top-K features from ranked_ks.csv.",
+ )
+ parser.add_argument(
+ "--lines-max-rows",
+ type=int,
+ default=1000,
+ help="Max rows to read from generated.csv for --figure lines.",
+ )
+ parser.add_argument(
+ "--lines-normalize",
+ choices=["none", "real_range"],
+ default="none",
+ help="Normalization for --figure lines: none or real_range (use cont_stats min/max).",
+ )
+ parser.add_argument(
+ "--reference",
+ default=str(base_dir / "config.json"),
+ help="Reference source for real data: config.json with data_glob or direct CSV/GZ path.",
+ )
+ parser.add_argument(
+ "--lines-ref-index",
+ type=int,
+ default=0,
+ help="Index of matched reference file to plot (0-based) when using a glob.",
+ )
+ parser.add_argument(
+ "--cdf-features",
+ default="",
+ help="Comma-separated features for cdf_grid; empty = all continuous features from cont_stats.",
+ )
+ parser.add_argument(
+ "--cdf-max-features",
+ type=int,
+ default=64,
+ help="Max features to include in cdf_grid.",
+ )
+ parser.add_argument(
+ "--cdf-bins",
+ type=int,
+ default=80,
+ help="Number of bins for empirical CDF.",
+ )
+ parser.add_argument(
+ "--disc-features",
+ default="",
+ help="Comma-separated features for disc_grid; empty = use discrete list from feature_split.json.",
+ )
+ parser.add_argument(
+ "--disc-max-features",
+ type=int,
+ default=64,
+ help="Max features to include in disc_grid.",
+ )
+ parser.add_argument(
+ "--feature-split",
+ default=str(base_dir / "feature_split.json"),
+ help="Path to feature_split.json with 'continuous' and 'discrete' lists.",
+ )
return parser.parse_args()
@@ -181,6 +281,109 @@ def plot_matplotlib(rows, seeds, metrics, out_path):
fig.savefig(out_path, format="svg")
plt.close(fig)
+def discrete_points_matplotlib(generated_csv_path, reference_arg, features, out_path, max_rows=5000):
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import gzip
+ try:
+ plt.style.use("seaborn-v0_8-whitegrid")
+ except Exception:
+ pass
+ def resolve_reference_glob(ref_arg: str) -> str:
+ ref_path = Path(ref_arg)
+ if ref_path.suffix == ".json":
+ cfg = json.loads(ref_path.read_text(encoding="utf-8"))
+ data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
+ if not data_glob:
+ raise SystemExit("reference config has no data_glob/data_path")
+ combined = ref_path.parent / data_glob
+ if "*" in str(combined) or "?" in str(combined):
+ return str(combined)
+ return str(combined.resolve())
+ return str(ref_path)
+ def read_rows(path, limit):
+ rows = []
+ opener = gzip.open if str(path).endswith(".gz") else open
+ with opener(path, "rt", newline="") as fh:
+ reader = csv.DictReader(fh)
+ for i, r in enumerate(reader):
+ rows.append(r)
+ if limit > 0 and i + 1 >= limit:
+ break
+ return rows
+ def cats(rows, feat):
+ vs = []
+ for r in rows:
+ v = r.get(feat)
+ if v is None:
+ continue
+ s = str(v).strip()
+ if s == "" or s.lower() == "nan":
+ continue
+ vs.append(s)
+ return vs
+ ref_glob = resolve_reference_glob(reference_arg)
+ ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
+ gen_rows = read_rows(generated_csv_path, max_rows)
+ ref_rows = read_rows(ref_paths[0] if ref_paths else generated_csv_path, max_rows)
+ points = []
+ group_spans = []
+ x = 0
+ for feat in features:
+ gvs = cats(gen_rows, feat)
+ rvs = cats(ref_rows, feat)
+ cats_all = sorted(set(gvs) | set(rvs))
+ start_x = x
+ for c in cats_all:
+ g_count = sum(1 for v in gvs if v == c)
+ r_count = sum(1 for v in rvs if v == c)
+ g_total = len(gvs) or 1
+ r_total = len(rvs) or 1
+ g_p = g_count / g_total
+ r_p = r_count / r_total
+ points.append({"x": x, "feat": feat, "cat": c, "g": g_p, "r": r_p})
+ x += 1
+ end_x = x - 1
+ if end_x >= start_x:
+ group_spans.append({"feat": feat, "x0": start_x, "x1": end_x})
+ x += 1
+ if not points:
+ fig, ax = plt.subplots(figsize=(9, 3))
+ ax.axis("off")
+ ax.text(0.5, 0.5, "no discrete data", ha="center", va="center")
+ out_path.parent.mkdir(parents=True, exist_ok=True)
+ fig.savefig(out_path, format="svg")
+ plt.close(fig)
+ return
+ width = max(9.0, min(18.0, 0.08 * len(points)))
+ fig, ax = plt.subplots(figsize=(width, 6.2))
+ xs = np.array([p["x"] for p in points], dtype=float)
+ jitter = 0.18
+ ax.scatter(xs - jitter, [p["g"] for p in points], s=26, color="#2563eb", alpha=0.85, edgecolors="white", linewidths=0.6, label="generated")
+ ax.scatter(xs + jitter, [p["r"] for p in points], s=26, color="#ef4444", alpha=0.75, edgecolors="white", linewidths=0.6, label="real")
+ for span in group_spans:
+ xc = (span["x0"] + span["x1"]) / 2.0
+ ax.axvline(span["x1"] + 0.5, color="#e5e7eb", lw=1.0)
+ ax.text(xc, -0.06, span["feat"], ha="center", va="top", rotation=25, fontsize=8, color="#374151")
+ xg_all = xs - jitter
+ yg_all = [p["g"] for p in points]
+ xr_all = xs + jitter
+ yr_all = [p["r"] for p in points]
+ ax.plot(xg_all, yg_all, color="#2563eb", linewidth=1.2, alpha=0.85)
+ ax.plot(xr_all, yr_all, color="#ef4444", linewidth=1.2, alpha=0.75)
+ ax.fill_between(xg_all, yg_all, 0.0, color="#2563eb", alpha=0.14, step=None)
+ ax.fill_between(xr_all, yr_all, 0.0, color="#ef4444", alpha=0.14, step=None)
+ ax.set_ylim(-0.08, 1.02)
+ ax.set_xlim(-0.5, max(xs) + 0.5)
+ ax.set_ylabel("probability", fontsize=10)
+ ax.set_xticks([])
+ ax.grid(True, axis="y", color="#e5e7eb")
+ ax.legend(loc="upper right", fontsize=9)
+ fig.suptitle("Discrete Marginals (dot plot): generated vs real", fontsize=12, color="#111827", y=0.98)
+ fig.tight_layout(rect=(0, 0, 1, 0.96))
+ out_path.parent.mkdir(parents=True, exist_ok=True)
+ fig.savefig(out_path, format="svg")
+ plt.close(fig)
def plot_svg(rows, seeds, metrics, out_path):
W, H = 980, 440
@@ -305,6 +508,489 @@ def plot_svg(rows, seeds, metrics, out_path):
out_path.write_text("\n".join(parts), encoding="utf-8")
+def ranked_ks_svg(ranked_rows, out_path, top_n=30):
+ parsed = []
+ for r in ranked_rows or []:
+ feat = (r.get("feature") or "").strip()
+ if not feat:
+ continue
+ ks = parse_float(r.get("ks"))
+ if ks is None:
+ continue
+ rank = parse_float(r.get("rank"))
+ contrib = parse_float(r.get("contribution_to_avg"))
+ avg_if = parse_float(r.get("avg_ks_if_remove_top_n"))
+ parsed.append(
+ {
+ "rank": int(rank) if isinstance(rank, (int, float)) else None,
+ "feature": feat,
+ "ks": float(ks),
+ "contrib": float(contrib) if isinstance(contrib, (int, float)) else None,
+ "avg_if": float(avg_if) if isinstance(avg_if, (int, float)) else None,
+ }
+ )
+
+ if not parsed:
+ raise SystemExit("no valid rows in ranked_ks.csv")
+
+ parsed_by_ks = sorted(parsed, key=lambda x: x["ks"], reverse=True)
+ top_n = max(1, int(top_n))
+ top = parsed_by_ks[: min(top_n, len(parsed_by_ks))]
+ parsed_by_rank = sorted(
+ [p for p in parsed if isinstance(p.get("rank"), int) and p.get("avg_if") is not None],
+ key=lambda x: x["rank"],
+ )
+
+ baseline = None
+ if parsed_by_rank:
+ first = parsed_by_rank[0]
+ if first.get("avg_if") is not None and first.get("contrib") is not None:
+ baseline = first["avg_if"] + first["contrib"]
+ elif first.get("avg_if") is not None:
+ baseline = first["avg_if"]
+ if baseline is None:
+ baseline = sum(p["ks"] for p in parsed_by_ks) / len(parsed_by_ks)
+
+ xs = [0]
+ ys = [baseline]
+ for p in parsed_by_rank:
+ xs.append(p["rank"])
+ ys.append(p["avg_if"])
+
+ W, H = 980, 560
+ bg = "#ffffff"
+ ink = "#0f172a"
+ subtle = "#64748b"
+ grid = "#e2e8f0"
+ border = "#cbd5e1"
+ blue = "#2563eb"
+ bar = "#0ea5e9"
+
+ ff = "system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif"
+
+ def text(x, y, s, size=12, anchor="start", color=ink, weight="normal"):
+        return "<text x='{x}' y='{y}' text-anchor='{a}' font-family='{ff}' font-size='{fs}' fill='{c}' font-weight='{w}'>{t}</text>".format(
+ x=x,
+ y=y,
+ a=anchor,
+ ff=ff,
+ fs=size,
+ c=color,
+ w=weight,
+ t=svg_escape(s),
+ )
+
+ def line(x1, y1, x2, y2, color=grid, width=1.0, dash=None, opacity=1.0, cap="round"):
+ extra = ""
+ if dash:
+ extra += " stroke-dasharray='{d}'".format(d=dash)
+ if opacity != 1.0:
+ extra += " stroke-opacity='{o}'".format(o=opacity)
+        return "<line x1='{x1}' y1='{y1}' x2='{x2}' y2='{y2}' stroke='{c}' stroke-width='{w}' stroke-linecap='{cap}'{extra}/>".format(
+ x1=x1,
+ y1=y1,
+ x2=x2,
+ y2=y2,
+ c=color,
+ w=width,
+ cap=cap,
+ extra=extra,
+ )
+
+ pad = 26
+ title_h = 62
+ gap = 18
+ card_w = (W - 2 * pad - gap) / 2
+ card_h = H - pad - title_h
+ xA = pad
+ xB = pad + card_w + gap
+ y0 = title_h
+
+ parts = []
+ parts.append("")
+ out_path.parent.mkdir(parents=True, exist_ok=True)
+ out_path.write_text("\n".join(parts), encoding="utf-8")
+
+
+def lines_matplotlib(generated_csv_path, cont_stats, features, out_path, max_rows=1000, normalize="none", reference_arg="", ref_index=0):
+ import matplotlib.pyplot as plt
+ import gzip
+ try:
+ plt.style.use("seaborn-v0_8-whitegrid")
+ except Exception:
+ pass
+
+ rows_gen = []
+ with Path(generated_csv_path).open("r", encoding="utf-8", newline="") as f:
+ reader = csv.DictReader(f)
+ for i, r in enumerate(reader):
+ rows_gen.append(r)
+ if len(rows_gen) >= max_rows:
+ break
+
+ xs_gen = [parse_float(r.get("time")) or (i if (r.get("time") is None) else 0.0) for i, r in enumerate(rows_gen)]
+
+ mins = {}
+ maxs = {}
+ if isinstance(cont_stats, dict):
+ mins = cont_stats.get("min", {}) or {}
+ maxs = cont_stats.get("max", {}) or {}
+
+ def norm_val(feat, v):
+ if normalize != "real_range":
+ return v
+ mn = parse_float(mins.get(feat))
+ mx = parse_float(maxs.get(feat))
+ if mn is None or mx is None:
+ return v
+ denom = mx - mn
+ if denom == 0:
+ return v
+ return (v - mn) / denom
+
+ def resolve_reference_glob(ref_arg: str) -> str:
+ ref_path = Path(ref_arg)
+ if ref_path.suffix == ".json":
+ cfg = json.loads(ref_path.read_text(encoding="utf-8"))
+ data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
+ if not data_glob:
+ raise SystemExit("reference config has no data_glob/data_path")
+ combined = ref_path.parent / data_glob
+ if "*" in str(combined) or "?" in str(combined):
+ return str(combined)
+ return str(combined.resolve())
+ return str(ref_path)
+
+ def read_series(path: Path, cols, max_rows: int):
+ vals = {c: [] for c in cols}
+ opener = gzip.open if str(path).endswith(".gz") else open
+ with opener(path, "rt", newline="") as fh:
+ reader = csv.DictReader(fh)
+ for i, row in enumerate(reader):
+ for c in cols:
+ try:
+ vals[c].append(float(row[c]))
+ except Exception:
+ pass
+ if max_rows > 0 and i + 1 >= max_rows:
+ break
+ return vals
+
+ ref_glob = resolve_reference_glob(reference_arg or str(Path(__file__).resolve().parent / "config.json"))
+ ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
+ ref_rows = []
+ if ref_paths:
+ idx = max(0, min(ref_index, len(ref_paths) - 1))
+ first = ref_paths[idx]
+ opener = gzip.open if str(first).endswith(".gz") else open
+ with opener(first, "rt", newline="") as fh:
+ reader = csv.DictReader(fh)
+ for i, r in enumerate(reader):
+ ref_rows.append(r)
+ if len(ref_rows) >= max_rows:
+ break
+ xs_ref = [i for i, _ in enumerate(ref_rows)]
+
+ fig, axes = plt.subplots(nrows=len(features), ncols=1, figsize=(9.2, 6.6), sharex=True)
+ if len(features) == 1:
+ axes = [axes]
+
+ for ax, feat in zip(axes, features, strict=True):
+ ys_gen = [norm_val(feat, parse_float(r.get(feat)) or 0.0) for r in rows_gen]
+ ax.plot(xs_gen, ys_gen, color="#2563eb", linewidth=1.6, label="generated")
+ if ref_rows:
+ ys_ref = [norm_val(feat, parse_float(r.get(feat)) or 0.0) for r in ref_rows]
+ ax.plot(xs_ref, ys_ref, color="#ef4444", linewidth=1.2, alpha=0.75, label="real")
+ ax.set_ylabel(feat, fontsize=10)
+ ax.grid(True, color="#e5e7eb")
+ ax.legend(loc="upper right", fontsize=8)
+
+ axes[-1].set_xlabel("time", fontsize=10)
+ fig.suptitle("Feature Series: generated vs real", fontsize=12, color="#111827", y=0.98)
+ fig.tight_layout(rect=(0, 0, 1, 0.96))
+
+ out_path.parent.mkdir(parents=True, exist_ok=True)
+ fig.savefig(out_path, format="svg")
+ plt.close(fig)
+
+def cdf_grid_matplotlib(generated_csv_path, reference_arg, cont_stats, features, out_path, max_rows=5000, bins=80):
+ import matplotlib.pyplot as plt
+ import gzip
+ import numpy as np
+ try:
+ plt.style.use("seaborn-v0_8-whitegrid")
+ except Exception:
+ pass
+ mins = cont_stats.get("min", {}) if isinstance(cont_stats, dict) else {}
+ maxs = cont_stats.get("max", {}) if isinstance(cont_stats, dict) else {}
+ def resolve_reference_glob(ref_arg: str) -> str:
+ ref_path = Path(ref_arg)
+ if ref_path.suffix == ".json":
+ cfg = json.loads(ref_path.read_text(encoding="utf-8"))
+ data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
+ if not data_glob:
+ raise SystemExit("reference config has no data_glob/data_path")
+ combined = ref_path.parent / data_glob
+ if "*" in str(combined) or "?" in str(combined):
+ return str(combined)
+ return str(combined.resolve())
+ return str(ref_path)
+ def read_rows(path, limit):
+ rows = []
+ opener = gzip.open if str(path).endswith(".gz") else open
+ with opener(path, "rt", newline="") as fh:
+ reader = csv.DictReader(fh)
+ for i, r in enumerate(reader):
+ rows.append(r)
+ if limit > 0 and i + 1 >= limit:
+ break
+ return rows
+ gen_rows = read_rows(generated_csv_path, max_rows)
+ ref_glob = resolve_reference_glob(reference_arg)
+ ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
+ ref_rows = read_rows(ref_paths[0] if ref_paths else generated_csv_path, max_rows)
+ def values(rows, feat):
+ vs = []
+ for r in rows:
+ x = parse_float(r.get(feat))
+ if x is not None:
+ vs.append(x)
+ return vs
+ cols = 4
+ rows_n = int(math.ceil(len(features) / cols)) if features else 1
+ fig, axes = plt.subplots(nrows=rows_n, ncols=cols, figsize=(cols * 3.2, rows_n * 2.6))
+ axes = np.array(axes).reshape(rows_n, cols)
+ for i, feat in enumerate(features):
+ rr = i // cols
+ cc = i % cols
+ ax = axes[rr][cc]
+ gvs = values(gen_rows, feat)
+ rvs = values(ref_rows, feat)
+ mn = parse_float(mins.get(feat))
+ mx = parse_float(maxs.get(feat))
+ if mn is None or mx is None or mx <= mn:
+ lo = min(gvs + rvs) if (gvs or rvs) else 0.0
+ hi = max(gvs + rvs) if (gvs or rvs) else (lo + 1.0)
+ else:
+ lo = mn
+ hi = mx
+ edges = np.linspace(lo, hi, bins + 1)
+ def ecdf(vs):
+ if not vs:
+ return edges[1:], np.zeros_like(edges[1:])
+ hist, _ = np.histogram(vs, bins=edges)
+ cdf = np.cumsum(hist).astype(float)
+ cdf /= cdf[-1] if cdf[-1] > 0 else 1.0
+ xs = edges[1:]
+ return xs, cdf
+ xg, yg = ecdf(gvs)
+ xr, yr = ecdf(rvs)
+ ax.plot(xg, yg, color="#2563eb", linewidth=1.6, label="generated")
+ ax.plot(xr, yr, color="#ef4444", linewidth=1.2, alpha=0.85, label="real")
+ ax.set_title(feat, fontsize=9, loc="left")
+ ax.set_ylim(0, 1)
+ ax.grid(True, color="#e5e7eb")
+ for j in range(i + 1, rows_n * cols):
+ rr = j // cols
+ cc = j % cols
+ axes[rr][cc].axis("off")
+ handles, labels = axes[0][0].get_legend_handles_labels()
+ fig.legend(handles, labels, loc="upper center", ncol=2, fontsize=9)
+ fig.suptitle("Empirical CDF: generated vs real", fontsize=12, color="#111827", y=0.98)
+ fig.tight_layout(rect=(0, 0, 1, 0.96))
+ out_path.parent.mkdir(parents=True, exist_ok=True)
+ fig.savefig(out_path, format="svg")
+ plt.close(fig)
+
+def discrete_grid_matplotlib(generated_csv_path, reference_arg, features, out_path, max_rows=5000):
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import gzip
+ try:
+ plt.style.use("seaborn-v0_8-whitegrid")
+ except Exception:
+ pass
+ def resolve_reference_glob(ref_arg: str) -> str:
+ ref_path = Path(ref_arg)
+ if ref_path.suffix == ".json":
+ cfg = json.loads(ref_path.read_text(encoding="utf-8"))
+ data_glob = cfg.get("data_glob") or cfg.get("data_path") or ""
+ if not data_glob:
+ raise SystemExit("reference config has no data_glob/data_path")
+ combined = ref_path.parent / data_glob
+ if "*" in str(combined) or "?" in str(combined):
+ return str(combined)
+ return str(combined.resolve())
+ return str(ref_path)
+ def read_rows(path, limit):
+ rows = []
+ opener = gzip.open if str(path).endswith(".gz") else open
+ with opener(path, "rt", newline="") as fh:
+ reader = csv.DictReader(fh)
+ for i, r in enumerate(reader):
+ rows.append(r)
+ if limit > 0 and i + 1 >= limit:
+ break
+ return rows
+ ref_glob = resolve_reference_glob(reference_arg)
+ ref_paths = sorted(Path(ref_glob).parent.glob(Path(ref_glob).name))
+ gen_rows = read_rows(generated_csv_path, max_rows)
+ ref_rows = read_rows(ref_paths[0] if ref_paths else generated_csv_path, max_rows)
+ def cats(rows, feat):
+ vs = []
+ for r in rows:
+ v = r.get(feat)
+ if v is None:
+ continue
+ s = str(v).strip()
+ if s == "" or s.lower() == "nan":
+ continue
+ vs.append(s)
+ return vs
+ cols = 4
+ rows_n = int(math.ceil(len(features) / cols)) if features else 1
+ fig, axes = plt.subplots(nrows=rows_n, ncols=cols, figsize=(cols * 3.0, rows_n * 2.6), sharey=False)
+ axes = np.array(axes).reshape(rows_n, cols)
+ for i, feat in enumerate(features):
+ rr = i // cols
+ cc = i % cols
+ ax = axes[rr][cc]
+ gvs = cats(gen_rows, feat)
+ rvs = cats(ref_rows, feat)
+ all_vals = sorted(set(gvs) | set(rvs))
+ if not all_vals:
+ ax.axis("off")
+ continue
+ g_counts = {v: 0 for v in all_vals}
+ r_counts = {v: 0 for v in all_vals}
+ for v in gvs:
+ g_counts[v] += 1
+ for v in rvs:
+ r_counts[v] += 1
+ g_total = sum(g_counts.values()) or 1
+ r_total = sum(r_counts.values()) or 1
+ g_p = [g_counts[v] / g_total for v in all_vals]
+ r_p = [r_counts[v] / r_total for v in all_vals]
+ x = np.arange(len(all_vals))
+ w = 0.42
+ ax.bar(x - w / 2, g_p, width=w, color="#2563eb", alpha=0.85, label="generated")
+ ax.bar(x + w / 2, r_p, width=w, color="#ef4444", alpha=0.75, label="real")
+ ax.set_title(feat, fontsize=9, loc="left")
+ ax.set_xticks(x)
+ ax.set_xticklabels(all_vals, rotation=25, ha="right", fontsize=8)
+ ax.set_ylim(0, 1)
+ ax.grid(True, axis="y", color="#e5e7eb")
+ for j in range(i + 1, rows_n * cols):
+ rr = j // cols
+ cc = j % cols
+ axes[rr][cc].axis("off")
+ handles, labels = axes[0][0].get_legend_handles_labels()
+ fig.legend(handles, labels, loc="upper center", ncol=2, fontsize=9)
+ fig.suptitle("Discrete Marginals: generated vs real", fontsize=12, color="#111827", y=0.98)
+ fig.tight_layout(rect=(0, 0, 1, 0.96))
+ out_path.parent.mkdir(parents=True, exist_ok=True)
+ fig.savefig(out_path, format="svg")
+ plt.close(fig)
def read_csv_rows(path):
p = Path(path)
if not p.exists():
@@ -330,6 +1016,126 @@ def parse_float(s):
return float(ss)
+def compute_csv_means(path):
+ p = Path(path)
+ if not p.exists():
+ return {}
+ with p.open("r", encoding="utf-8", newline="") as f:
+ reader = csv.DictReader(f)
+ sums = {}
+ counts = {}
+ for r in reader:
+ for k, v in r.items():
+ x = parse_float(v)
+ if x is None:
+ continue
+ sums[k] = sums.get(k, 0.0) + x
+ counts[k] = counts.get(k, 0) + 1
+ means = {}
+ for k, s in sums.items():
+ c = counts.get(k, 0)
+ if c > 0:
+ means[k] = s / c
+ return means
+
+
+def build_mean_profile(ks_rows, cont_stats, generated_csv_path, order="ks_desc", max_features=64):
+ if not isinstance(cont_stats, dict):
+ return None
+ means_real = cont_stats.get("mean")
+ if not isinstance(means_real, dict):
+ return None
+
+ means_gen = compute_csv_means(generated_csv_path)
+ if not means_gen:
+ return None
+
+ ks_by_feat = {}
+ mins = {}
+ maxs = {}
+ for r in ks_rows or []:
+ feat = (r.get("feature") or "").strip()
+ if not feat:
+ continue
+ ks = parse_float(r.get("ks"))
+ mn = parse_float(r.get("real_min"))
+ mx = parse_float(r.get("real_max"))
+ if ks is not None:
+ ks_by_feat[feat] = ks
+ if mn is not None:
+ mins[feat] = mn
+ if mx is not None:
+ maxs[feat] = mx
+
+ feats = []
+ real_vals = []
+ gen_vals = []
+ ks_vals = []
+ for feat, mu_real in means_real.items():
+ if feat not in means_gen:
+ continue
+ if feat not in mins or feat not in maxs:
+ continue
+ mn = mins[feat]
+ mx = maxs[feat]
+ denom = mx - mn
+ if denom == 0:
+ continue
+ mu_gen = means_gen[feat]
+ feats.append(feat)
+ real_vals.append(clamp((mu_real - mn) / denom, 0.0, 1.0))
+ gen_vals.append(clamp((mu_gen - mn) / denom, 0.0, 1.0))
+ ks_vals.append(ks_by_feat.get(feat))
+
+ if not feats:
+ return None
+
+ idx = list(range(len(feats)))
+ if order == "name":
+ idx.sort(key=lambda i: feats[i])
+ else:
+ idx.sort(key=lambda i: ks_by_feat.get(feats[i], -1.0), reverse=True)
+
+ if isinstance(max_features, int) and max_features > 0 and len(idx) > max_features:
+ idx = idx[:max_features]
+
+ sel_feats = [feats[i] for i in idx]
+ sel_real = [real_vals[i] for i in idx]
+ sel_gen = [gen_vals[i] for i in idx]
+ sel_ks = [ks_vals[i] for i in idx]
+ sel_diff = [abs(a - b) for a, b in zip(sel_real, sel_gen, strict=True)]
+
+ def pearsonr(x, y):
+ if not x or len(x) != len(y):
+ return None
+ n = len(x)
+ mx = sum(x) / n
+ my = sum(y) / n
+ vx = sum((xi - mx) ** 2 for xi in x)
+ vy = sum((yi - my) ** 2 for yi in y)
+ if vx <= 0 or vy <= 0:
+ return None
+ cov = sum((xi - mx) * (yi - my) for xi, yi in zip(x, y, strict=True))
+ return cov / math.sqrt(vx * vy)
+
+ r = pearsonr(sel_real, sel_gen)
+ mae = (sum(sel_diff) / len(sel_diff)) if sel_diff else None
+
+ points = []
+ for f, xr, yg, ks, d in zip(sel_feats, sel_real, sel_gen, sel_ks, sel_diff, strict=True):
+ points.append({"feature": f, "x": xr, "y": yg, "ks": ks, "diff": d})
+
+ return {
+ "features": sel_feats,
+ "real": sel_real,
+ "gen": sel_gen,
+ "ks": sel_ks,
+ "diff": sel_diff,
+ "points": points,
+ "stats": {"n": len(sel_feats), "r": r, "mae": mae},
+ }
+
+
def zscores(vals):
if not vals:
return []
@@ -341,7 +1147,7 @@ def zscores(vals):
return [(x - m) / s for x in vals]
-def panel_svg(bh_rows, ks_rows, shift_rows, hist_rows, filtered_metrics, out_path):
+def panel_svg(bh_rows, ks_rows, shift_rows, hist_rows, filtered_metrics, profile, out_path):
W, H = 1400, 900
margin = 42
gap = 26
@@ -357,15 +1163,36 @@ def panel_svg(bh_rows, ks_rows, shift_rows, hist_rows, filtered_metrics, out_pat
blue = "#3b82f6"
red = "#ef4444"
green = "#10b981"
+ card_bg = "#ffffff"
+ card_shadow = "#0f172a"
+ plot_bg = "#f8fafc"
def panel_rect(x, y):
-        return "<rect x='{x}' y='{y}' width='{w}' height='{h}' fill='none' stroke='{b}'/>".format(
- x=x, y=y, w=panel_w, h=panel_h, b=border
+ return (
+            "<rect x='{x}' y='{y}' width='{w}' height='{h}' rx='10' fill='{s}' fill-opacity='0.06'/>"
+            "<rect x='{x0}' y='{y0}' width='{w}' height='{h}' rx='10' fill='{f}' stroke='{b}'/>"
+ ).format(
+ x=x + 2.0,
+ y=y + 3.0,
+ x0=x,
+ y0=y,
+ w=panel_w,
+ h=panel_h,
+ f=card_bg,
+ b=border,
+ s=card_shadow,
)
def text(x, y, s, size=12, anchor="start", color=ink, weight="normal"):
-        return "<text x='{x}' y='{y}' text-anchor='{a}' font-size='{fs}' fill='{c}' font-weight='{w}'>{t}</text>".format(
- x=x, y=y, a=anchor, fs=size, c=color, w=weight, t=svg_escape(s)
+        return "<text x='{x}' y='{y}' text-anchor='{a}' font-family='{ff}' font-size='{fs}' fill='{c}' font-weight='{w}'>{t}</text>".format(
+ x=x,
+ y=y,
+ a=anchor,
+ ff="system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif",
+ fs=size,
+ c=color,
+ w=weight,
+ t=svg_escape(s),
)
def line(x1, y1, x2, y2, color=border, width=1.0, dash=None, opacity=1.0, cap="round"):
@@ -403,13 +1230,20 @@ def panel_svg(bh_rows, ks_rows, shift_rows, hist_rows, filtered_metrics, out_pat
"