#!/usr/bin/env python3
"""
Benchmark visualization for per-feature KS / JSD / lag-1 diff + overall averages.

- Uses dummy data generators (AR(1) for continuous; Markov chain for discrete)
- Replace `real_df` and `gen_df` with your real/generated samples.
- Outputs publication-quality figures into ./benchmark_figs/

Dependencies:
    pip install numpy pandas matplotlib scipy
    (Optional for nicer label wrapping)
    pip install textwrap3
"""

from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.spatial.distance import jensenshannon
from scipy.stats import ks_2samp

# ----------------------------
|
||
# 1) Dummy data generation
|
||
# ----------------------------
|
||
|
||
def ar1_series(T: int, phi: float, sigma: float, mu: float = 0.0, seed: int | None = None) -> np.ndarray:
|
||
rng = np.random.default_rng(seed)
|
||
x = np.zeros(T, dtype=np.float64)
|
||
eps = rng.normal(0.0, sigma, size=T)
|
||
x[0] = mu + eps[0]
|
||
for t in range(1, T):
|
||
x[t] = mu + phi * (x[t - 1] - mu) + eps[t]
|
||
return x
|
||
|
||
def markov_chain(T: int, P: np.ndarray, pi0: np.ndarray, seed: int | None = None) -> np.ndarray:
|
||
"""Return integer states 0..K-1."""
|
||
rng = np.random.default_rng(seed)
|
||
K = P.shape[0]
|
||
states = np.zeros(T, dtype=np.int64)
|
||
states[0] = rng.choice(K, p=pi0)
|
||
for t in range(1, T):
|
||
states[t] = rng.choice(K, p=P[states[t - 1]])
|
||
return states
|
||
|
||
def make_dummy_real_gen(
    T: int = 8000,
    n_cont: int = 18,
    n_disc: int = 6,
    disc_vocab_sizes: List[int] | None = None,
    seed: int = 7,
) -> Tuple[pd.DataFrame, pd.DataFrame, Dict[str, str], Dict[str, List[str]]]:
    """
    Build paired dummy "real" and "generated" datasets.

    Returns:
        real_df, gen_df, feature_types, disc_vocab_map
    where:
        feature_types[col] in {"continuous","discrete"}
        disc_vocab_map[col] = list of token strings (for discrete vars)
    """
    rng = np.random.default_rng(seed)

    cont_cols = [f"c{i:02d}" for i in range(n_cont)]
    disc_cols = [f"d{i:02d}" for i in range(n_disc)]

    if disc_vocab_sizes is None:
        base = [3, 4, 5, 3, 6, 4]
        disc_vocab_sizes = (base + [4] * max(0, n_disc - len(base)))[:n_disc]
    elif len(disc_vocab_sizes) < n_disc:
        # A too-short list is extended by repeating its last entry.
        pad = [disc_vocab_sizes[-1]] * (n_disc - len(disc_vocab_sizes))
        disc_vocab_sizes = list(disc_vocab_sizes) + pad

    real_data: Dict[str, np.ndarray] = {}
    gen_data: Dict[str, np.ndarray] = {}

    # Continuous columns: AR(1); the "generated" side uses slightly perturbed
    # parameters so the two marginals are close but not identical.
    # NOTE: the rng draw order below is deliberately kept stable so a fixed
    # seed reproduces the same datasets.
    for i, col in enumerate(cont_cols):
        phi_r = rng.uniform(0.75, 0.98)
        sigma_r = rng.uniform(0.3, 1.2)
        mu_r = rng.uniform(-1.0, 1.0)

        phi_g = np.clip(phi_r + rng.normal(0, 0.02), 0.60, 0.995)
        sigma_g = max(0.05, sigma_r * rng.uniform(0.85, 1.15))
        mu_g = mu_r + rng.normal(0, 0.15)

        real_data[col] = ar1_series(T, phi_r, sigma_r, mu_r, seed=seed + 100 + i)
        gen_data[col] = ar1_series(T, phi_g, sigma_g, mu_g, seed=seed + 200 + i)

    # Discrete columns: Markov chains with slightly perturbed transition matrices.
    disc_vocab_map: Dict[str, List[str]] = {}
    for j, col in enumerate(disc_cols):
        K = disc_vocab_sizes[j]
        vocab = [f"s{k}" for k in range(K)]
        disc_vocab_map[col] = vocab

        # Random but stable transition matrix for the "real" chain.
        raw = rng.uniform(0.1, 1.0, size=(K, K))
        P_real = raw / raw.sum(axis=1, keepdims=True)
        pi0 = rng.uniform(0.1, 1.0, size=K)
        pi0 = pi0 / pi0.sum()

        # Perturbed copy for the "generated" chain; clip keeps rows positive
        # before renormalization.
        P_gen = np.clip(P_real + rng.normal(0, 0.06, size=(K, K)), 1e-6, None)
        P_gen = P_gen / P_gen.sum(axis=1, keepdims=True)

        real_states = markov_chain(T, P_real, pi0, seed=seed + 300 + j)
        gen_states = markov_chain(T, P_gen, pi0, seed=seed + 400 + j)

        # Store tokens (strings) to emphasize the discrete nature.
        real_data[col] = np.array([vocab[s] for s in real_states], dtype=object)
        gen_data[col] = np.array([vocab[s] for s in gen_states], dtype=object)

    real_df = pd.DataFrame(real_data)
    gen_df = pd.DataFrame(gen_data)

    feature_types = {**{c: "continuous" for c in cont_cols}, **{d: "discrete" for d in disc_cols}}
    return real_df, gen_df, feature_types, disc_vocab_map
# ----------------------------
# 2) Metric computation
# ----------------------------

def lag1_autocorr(x: np.ndarray) -> float:
    """Lag-1 Pearson autocorrelation (robust to constant arrays).

    Returns NaN for series shorter than 3 and 0.0 when either shifted
    half of the series is (near-)constant, where correlation is undefined.
    """
    arr = np.asarray(x, dtype=np.float64)
    if arr.size < 3:
        return np.nan
    lagged = arr[:-1]
    leading = arr[1:]
    # A (near-)constant half gives a 0/0 correlation; report 0 instead.
    if np.std(lagged) < 1e-12 or np.std(leading) < 1e-12:
        return 0.0
    return float(np.corrcoef(lagged, leading)[0, 1])
def jsd_discrete(real_tokens: np.ndarray, gen_tokens: np.ndarray, vocab: List[str], base: float = 2.0) -> float:
    """Jensen–Shannon divergence for discrete distributions over vocab.

    Empirical pmfs are smoothed by a tiny epsilon so no bin is exactly
    zero. With base=2 the result lies in [0, 1].
    """
    smooth = 1e-12
    counts_r = pd.Series(real_tokens).value_counts()
    counts_g = pd.Series(gen_tokens).value_counts()
    p = np.array([counts_r.get(tok, 0) for tok in vocab], dtype=np.float64) + smooth
    q = np.array([counts_g.get(tok, 0) for tok in vocab], dtype=np.float64) + smooth
    p = p / p.sum()
    q = q / q.sum()
    # scipy's jensenshannon returns the *distance* sqrt(JS); square it
    # to obtain the divergence.
    dist = jensenshannon(p, q, base=base)
    return float(dist * dist)
@dataclass
class Metrics:
    """One per-feature row of benchmark results.

    A metric that does not apply to the feature type (e.g. `ks` for a
    discrete feature) is stored as NaN — not None — so downstream
    pandas `mean(skipna=True)` aggregation works unchanged.
    """

    feature: str  # column name
    ftype: str    # "continuous" or "discrete"
    # Bug fix: the annotations were `float | np.nan`, which is invalid —
    # np.nan is a float *value*, not a type. NaN is already a float,
    # so plain `float` is the correct annotation.
    ks: float
    jsd: float
    lag1_diff: float
def compute_metrics(
    real_df: pd.DataFrame,
    gen_df: pd.DataFrame,
    feature_types: Dict[str, str],
    disc_vocab_map: Dict[str, List[str]] | None = None,
) -> pd.DataFrame:
    """Compute per-feature KS / JSD / lag-1 autocorrelation-difference metrics.

    Continuous features get a two-sample KS statistic and the absolute
    difference of lag-1 autocorrelations; discrete features get a JSD
    over their vocabulary plus a lag-1 diff on an integer encoding.
    Prints overall averages as a side effect and returns one row per feature.

    Raises:
        ValueError: if a feature type is neither "continuous" nor "discrete".
    """
    records: List[Metrics] = []

    for col in real_df.columns:
        ftype = feature_types[col]
        real_vals = real_df[col].to_numpy()
        gen_vals = gen_df[col].to_numpy()

        if ftype == "continuous":
            # Two-sample KS on the raw values.
            ks_stat = float(
                ks_2samp(
                    real_vals.astype(np.float64),
                    gen_vals.astype(np.float64),
                    alternative="two-sided",
                ).statistic
            )
            lag_gap = float(abs(lag1_autocorr(real_vals) - lag1_autocorr(gen_vals)))
            records.append(Metrics(col, ftype, ks_stat, np.nan, lag_gap))

        elif ftype == "discrete":
            # JSD over the categorical pmf; fall back to the observed
            # token set when no vocabulary is provided for this column.
            if disc_vocab_map is not None and col in disc_vocab_map:
                vocab = disc_vocab_map[col]
            else:
                vocab = sorted(list(set(real_vals).union(set(gen_vals))))
            jsd_val = jsd_discrete(real_vals, gen_vals, vocab=vocab, base=2.0)

            # Lag-1 "stickiness" via an integer encoding of the tokens
            # (set lag_gap = np.nan here if you want lag1 only for continuous).
            encode = {tok: i for i, tok in enumerate(vocab)}
            enc_real = np.array([encode[t] for t in real_vals], dtype=np.float64)
            enc_gen = np.array([encode[t] for t in gen_vals], dtype=np.float64)
            lag_gap = float(abs(lag1_autocorr(enc_real) - lag1_autocorr(enc_gen)))

            records.append(Metrics(col, ftype, np.nan, jsd_val, lag_gap))
        else:
            raise ValueError(f"Unknown feature type: {ftype}")

    df = pd.DataFrame([m.__dict__ for m in records])

    # Overall averages, kept separate because the metrics mean different things.
    overall = {
        "ks_avg_cont": float(df.loc[df.ftype == "continuous", "ks"].mean(skipna=True)),
        "jsd_avg_disc": float(df.loc[df.ftype == "discrete", "jsd"].mean(skipna=True)),
        "lag1_avg_all": float(df["lag1_diff"].mean(skipna=True)),
    }
    print("Overall averages:", overall)
    return df
# ----------------------------
|
||
# 3) Fancy but honest visualizations
|
||
# ----------------------------
|
||
|
||
def _ensure_dir(p: str) -> None:
|
||
os.makedirs(p, exist_ok=True)
|
||
|
||
def plot_heatmap_metrics(metrics_df: pd.DataFrame, outdir: str, title: str = "Per-feature benchmark (normalized)") -> None:
    """
    Heatmap of per-feature metrics.
    Each metric column is min–max scaled to [0,1] (ignoring NaNs) so that
    KS / JSD / lag1 share a single color scale. Rows are sorted worst-first
    by the max normalized score across the available metrics.
    """
    _ensure_dir(outdir)

    table = metrics_df.copy()
    show_cols = ["ks", "jsd", "lag1_diff"]

    # Min–max normalize each metric column independently; NaNs (metrics that
    # do not apply to a feature type) are left as NaN in the non-degenerate case.
    scaled = table[show_cols].copy()
    for name in show_cols:
        vals = scaled[name]
        if vals.notna().sum() == 0:
            continue  # metric absent for every feature
        lo = float(vals.min(skipna=True))
        hi = float(vals.max(skipna=True))
        scaled[name] = 0.0 if abs(hi - lo) < 1e-12 else (vals - lo) / (hi - lo)

    table["_severity"] = scaled.max(axis=1, skipna=True)
    table = table.sort_values("_severity", ascending=False).drop(columns=["_severity"])
    scaled = scaled.loc[table.index]

    # Heatmap via plain matplotlib (no seaborn dependency).
    fig, ax = plt.subplots(figsize=(9.5, max(4.0, 0.28 * len(table))))
    im = ax.imshow(scaled.to_numpy(), aspect="auto", interpolation="nearest")

    ax.set_yticks(np.arange(len(table)))
    ax.set_yticklabels(table["feature"].tolist(), fontsize=8)
    ax.set_xticks(np.arange(len(show_cols)))
    ax.set_xticklabels(show_cols, fontsize=10)
    ax.set_title(title, fontsize=13, pad=12)

    cbar = fig.colorbar(im, ax=ax, fraction=0.035, pad=0.02)
    cbar.set_label("min–max normalized (per metric)", rotation=90)

    # White minor-tick gridlines between cells for readability.
    ax.set_xticks(np.arange(-.5, len(show_cols), 1), minor=True)
    ax.set_yticks(np.arange(-.5, len(table), 1), minor=True)
    ax.grid(which="minor", color="white", linestyle="-", linewidth=0.6)
    ax.tick_params(which="minor", bottom=False, left=False)

    fig.tight_layout()
    fig.savefig(os.path.join(outdir, "heatmap_metrics_normalized.png"), dpi=300)
    plt.close(fig)
def plot_topk_lollipop(metrics_df: pd.DataFrame, metric: str, outdir: str, k: int = 15) -> None:
    """Fancy top-k lollipop chart for a metric (e.g., 'ks' or 'jsd' or 'lag1_diff')."""
    _ensure_dir(outdir)
    top = metrics_df.dropna(subset=[metric]).sort_values(metric, ascending=False).head(k)

    fig, ax = plt.subplots(figsize=(9.5, 0.45 * len(top) + 1.5))
    ypos = np.arange(len(top))[::-1]
    vals = top[metric].to_numpy()

    ax.hlines(y=ypos, xmin=0, xmax=vals, linewidth=2.0, alpha=0.85)  # stems
    ax.plot(vals, ypos, "o", markersize=7)  # heads

    ax.set_yticks(ypos)
    ax.set_yticklabels(top["feature"].tolist(), fontsize=9)
    ax.set_xlabel(metric)
    ax.set_title(f"Top-{k} features by {metric}", fontsize=13, pad=10)
    ax.grid(axis="x", linestyle="--", alpha=0.4)  # subtle x-grid

    fig.tight_layout()
    fig.savefig(os.path.join(outdir, f"topk_{metric}_lollipop.png"), dpi=300)
    plt.close(fig)
def plot_metric_distributions(metrics_df: pd.DataFrame, outdir: str) -> None:
    """Histogram distributions of each metric across features."""
    _ensure_dir(outdir)

    fig, ax = plt.subplots(figsize=(9.5, 5.2))
    # Overlaid translucent histograms give a compact three-way comparison.
    for name in ("ks", "jsd", "lag1_diff"):
        data = metrics_df[name].dropna().to_numpy()
        if data.size == 0:
            continue
        ax.hist(data, bins=20, alpha=0.55, density=True, label=name)

    ax.set_title("Distribution of per-feature benchmark metrics", fontsize=13, pad=10)
    ax.set_xlabel("metric value")
    ax.set_ylabel("density")
    ax.grid(axis="y", linestyle="--", alpha=0.35)
    ax.legend()
    fig.tight_layout()
    fig.savefig(os.path.join(outdir, "metric_distributions.png"), dpi=300)
    plt.close(fig)
def plot_overall_summary(metrics_df: pd.DataFrame, outdir: str) -> None:
    """Simple overall averages (separate meaning per metric)."""
    _ensure_dir(outdir)
    is_cont = metrics_df.ftype == "continuous"
    is_disc = metrics_df.ftype == "discrete"
    values = [
        metrics_df.loc[is_cont, "ks"].mean(skipna=True),
        metrics_df.loc[is_disc, "jsd"].mean(skipna=True),
        metrics_df["lag1_diff"].mean(skipna=True),
    ]
    labels = ["KS(avg cont)", "JSD(avg disc)", "Lag1 diff(avg)"]

    fig, ax = plt.subplots(figsize=(7.8, 4.6))
    ax.bar(labels, values)
    ax.set_title("Overall benchmark summary", fontsize=13, pad=10)
    ax.set_ylabel("value")
    ax.grid(axis="y", linestyle="--", alpha=0.35)
    fig.tight_layout()
    fig.savefig(os.path.join(outdir, "overall_summary_bar.png"), dpi=300)
    plt.close(fig)
# ----------------------------
|
||
# 4) continuous trends
|
||
# ----------------------------
|
||
|
||
|
||
def _smooth_ma(x: np.ndarray, w: int) -> np.ndarray:
|
||
"""Simple moving average smoothing; w>=1."""
|
||
if w <= 1:
|
||
return x.astype(np.float64, copy=True)
|
||
s = pd.Series(x.astype(np.float64))
|
||
return s.rolling(w, center=True, min_periods=max(2, w // 10)).mean().to_numpy()
|
||
|
||
def _rolling_quantile(x: np.ndarray, w: int, q: float) -> np.ndarray:
|
||
"""Rolling quantile for a local envelope."""
|
||
if w <= 1:
|
||
return x.astype(np.float64, copy=True)
|
||
s = pd.Series(x.astype(np.float64))
|
||
return s.rolling(w, center=True, min_periods=max(10, w // 5)).quantile(q).to_numpy()
|
||
|
||
def plot_trend_small_multiples(
    real_df: pd.DataFrame,
    gen_df: pd.DataFrame,
    metrics_df: pd.DataFrame,
    feature_types: Dict[str, str],
    outdir: str,
    k: int = 12,
    rank_by: str = "lag1_diff",  # or "ks"
    smooth_w: int = 200,
    env_w: int = 600,
    downsample: int = 1,
) -> None:
    """
    Small multiples of continuous-feature trends.
    - Picks top-k continuous features by rank_by (ks or lag1_diff).
    - Plots: faint raw + bold smoothed + rolling 10–90% envelope for real and gen.
    NOTE(review): `feature_types` is currently unused here (ranking reads
    metrics_df["ftype"]); kept for signature symmetry with the other plotters.
    """
    _ensure_dir(outdir)

    # pick top-k continuous features
    dfc = metrics_df[metrics_df["ftype"] == "continuous"].dropna(subset=[rank_by]).copy()
    dfc = dfc.sort_values(rank_by, ascending=False).head(k)
    feats = dfc["feature"].tolist()
    if not feats:
        print("[trend] No continuous features found to plot.")
        return

    n = len(feats)
    ncols = 3
    nrows = int(np.ceil(n / ncols))

    # Bug fix: this was `fig = plt.subplots(...)[0]`, which creates a figure
    # that already contains one empty Axes. That stray Axes sat under the
    # gridspec panels and — being fig.axes[0] — made the shared legend below
    # come out empty. plt.figure creates a bare figure instead.
    fig = plt.figure(figsize=(12, 3.2 * nrows))
    gs = fig.add_gridspec(nrows, ncols, hspace=0.35, wspace=0.22)

    for idx, feat in enumerate(feats):
        r = real_df[feat].to_numpy(dtype=np.float64)
        g = gen_df[feat].to_numpy(dtype=np.float64)

        if downsample > 1:
            r = r[::downsample]
            g = g[::downsample]

        t = np.arange(len(r), dtype=np.int64)

        # smoothed trend
        r_tr = _smooth_ma(r, smooth_w)
        g_tr = _smooth_ma(g, smooth_w)

        # local envelope (10–90%)
        r_lo = _rolling_quantile(r, env_w, 0.10)
        r_hi = _rolling_quantile(r, env_w, 0.90)
        g_lo = _rolling_quantile(g, env_w, 0.10)
        g_hi = _rolling_quantile(g, env_w, 0.90)

        ax = fig.add_subplot(gs[idx // ncols, idx % ncols])

        # faint raw “watermark”; only the first panel carries legend labels
        ax.plot(t, r, linewidth=0.6, alpha=0.18, label="real (raw)" if idx == 0 else None)
        ax.plot(t, g, linewidth=0.6, alpha=0.18, label="gen (raw)" if idx == 0 else None)

        # envelopes
        ax.fill_between(t, r_lo, r_hi, alpha=0.18, label="real (10–90%)" if idx == 0 else None)
        ax.fill_between(t, g_lo, g_hi, alpha=0.18, label="gen (10–90%)" if idx == 0 else None)

        # bold trends
        ax.plot(t, r_tr, linewidth=2.0, alpha=0.9, label="real (trend)" if idx == 0 else None)
        ax.plot(t, g_tr, linewidth=2.0, alpha=0.9, label="gen (trend)" if idx == 0 else None)

        # title with metric
        mval = float(dfc.loc[dfc["feature"] == feat, rank_by].iloc[0])
        ax.set_title(f"{feat} | {rank_by}={mval:.3f}", fontsize=10)

        # cosmetics
        ax.grid(axis="y", linestyle="--", alpha=0.25)
        ax.tick_params(labelsize=8)

    # shared legend (single): fig.axes[0] is now the first gridspec panel,
    # the only one whose artists carry labels (idx == 0 above).
    handles, labels = fig.axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper center", ncol=3, frameon=False, bbox_to_anchor=(0.5, 0.99))
    fig.suptitle("Trend comparison (continuous features): raw watermark + smoothed trend + local envelope",
                 fontsize=14, y=1.02)

    fig.tight_layout()
    fig.savefig(os.path.join(outdir, f"trend_small_multiples_{rank_by}.png"), dpi=300, bbox_inches="tight")
    plt.close(fig)
# ----------------------------
# 4) discrete trends
# ----------------------------


def plot_discrete_state_trends(
    real_df: pd.DataFrame,
    gen_df: pd.DataFrame,
    metrics_df: pd.DataFrame,
    feature_types: Dict[str, str],
    disc_vocab_map: Dict[str, List[str]] | None,
    outdir: str,
    k: int = 6,
    rank_by: str = "jsd",
    win: int = 400,
    step: int = 40,
) -> None:
    """
    For top-k discrete features by JSD, plot rolling state occupancy over time:
        p_t(v) = fraction of state v in window centered at t
    One figure per feature is written to outdir.
    """
    _ensure_dir(outdir)

    dfd = metrics_df[metrics_df["ftype"] == "discrete"].dropna(subset=[rank_by]).copy()
    dfd = dfd.sort_values(rank_by, ascending=False).head(k)
    feats = dfd["feature"].tolist()
    if not feats:
        print("[trend] No discrete features found to plot.")
        return

    for feat in feats:
        r = real_df[feat].to_numpy()
        g = gen_df[feat].to_numpy()

        # Bug fix: `disc_vocab_map.get(feat) if disc_vocab_map else sorted(...)`
        # returned None when the map existed but lacked this feature, crashing
        # below at len(vocab). Fall back to the observed token set in that
        # case as well.
        vocab = disc_vocab_map.get(feat) if disc_vocab_map else None
        if vocab is None:
            vocab = sorted(list(set(r).union(set(g))))

        # timeline centers
        centers = np.arange(win // 2, len(r) - win // 2, step, dtype=np.int64)
        t = centers

        # occupancy matrices: shape [len(vocab), len(t)]
        R = np.zeros((len(vocab), len(t)), dtype=np.float64)
        G = np.zeros((len(vocab), len(t)), dtype=np.float64)

        for ti, c in enumerate(centers):
            lo = c - win // 2
            hi = c + win // 2
            rw = r[lo:hi]
            gw = g[lo:hi]

            # per-window token counts
            rc = pd.Series(rw).value_counts()
            gc = pd.Series(gw).value_counts()
            for vi, v in enumerate(vocab):
                R[vi, ti] = rc.get(v, 0) / max(1, len(rw))
                G[vi, ti] = gc.get(v, 0) / max(1, len(gw))

        # Plot as two stacked areas: Real vs Gen (same vocab order)
        fig, ax = plt.subplots(figsize=(11.5, 4.8))
        ax.stackplot(t, R, alpha=0.35, labels=[f"{v} (real)" for v in vocab])
        ax.stackplot(t, G, alpha=0.35, labels=[f"{v} (gen)" for v in vocab])

        mval = float(dfd.loc[dfd["feature"] == feat, rank_by].iloc[0])
        ax.set_title(f"Discrete trend via rolling state occupancy — {feat} | {rank_by}={mval:.3f}", fontsize=13, pad=10)
        ax.set_xlabel("time index (window centers)")
        ax.set_ylabel("occupancy probability")
        ax.grid(axis="y", linestyle="--", alpha=0.25)

        # Put legend outside to keep plot clean
        ax.legend(loc="center left", bbox_to_anchor=(1.01, 0.5), frameon=False, fontsize=8)

        fig.tight_layout()
        fig.savefig(os.path.join(outdir, f"disc_state_trend_{feat}.png"), dpi=300, bbox_inches="tight")
        plt.close(fig)
# ----------------------------
# 4) Main entry
# ----------------------------

def main() -> None:
    """Generate dummy data, compute metrics, and write all figures + CSV."""
    outdir = "benchmark_figs"
    real_df, gen_df, feature_types, disc_vocab_map = make_dummy_real_gen(
        T=10000, n_cont=20, n_disc=8, seed=42
    )

    metrics_df = compute_metrics(real_df, gen_df, feature_types, disc_vocab_map)

    # Bug fix: outdir was only created inside the plotting helpers, so this
    # to_csv call crashed with FileNotFoundError on a fresh run. Create the
    # directory before writing anything into it.
    _ensure_dir(outdir)
    metrics_df.to_csv(os.path.join(outdir, "metrics_per_feature.csv"), index=False)

    plot_heatmap_metrics(metrics_df, outdir)
    plot_metric_distributions(metrics_df, outdir)
    plot_overall_summary(metrics_df, outdir)

    # Top-k charts: choose what you emphasize
    plot_topk_lollipop(metrics_df, metric="ks", outdir=outdir, k=15)
    plot_topk_lollipop(metrics_df, metric="jsd", outdir=outdir, k=15)
    plot_topk_lollipop(metrics_df, metric="lag1_diff", outdir=outdir, k=15)

    # Fancy trend figures (new)
    plot_trend_small_multiples(
        real_df, gen_df, metrics_df, feature_types, outdir,
        k=12, rank_by="lag1_diff", smooth_w=250, env_w=800, downsample=1
    )

    plot_trend_small_multiples(
        real_df, gen_df, metrics_df, feature_types, outdir,
        k=12, rank_by="ks", smooth_w=250, env_w=800, downsample=1
    )

    plot_discrete_state_trends(
        real_df, gen_df, metrics_df, feature_types, disc_vocab_map, outdir,
        k=6, rank_by="jsd", win=500, step=50
    )

    print(f"Saved figures + CSV under: ./{outdir}/")


if __name__ == "__main__":
    main()
|