Files
mask-ddpm/visualization/vis_benchmark.py
2026-02-02 16:35:23 +08:00

573 lines
20 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Benchmark visualization for per-feature KS / JSD / lag-1 diff + overall averages.
- Uses dummy data generators (AR(1) for continuous; Markov chain for discrete)
- Replace `real_df` and `gen_df` with your real/generated samples.
- Outputs publication-quality figures into ./benchmark_figs/
Dependencies:
pip install numpy pandas matplotlib scipy
(Optional for nicer label wrapping)
pip install textwrap3
"""
from __future__ import annotations
import os
from dataclasses import dataclass
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import ks_2samp
from scipy.spatial.distance import jensenshannon
# ----------------------------
# 1) Dummy data generation
# ----------------------------
def ar1_series(T: int, phi: float, sigma: float, mu: float = 0.0, seed: int | None = None) -> np.ndarray:
rng = np.random.default_rng(seed)
x = np.zeros(T, dtype=np.float64)
eps = rng.normal(0.0, sigma, size=T)
x[0] = mu + eps[0]
for t in range(1, T):
x[t] = mu + phi * (x[t - 1] - mu) + eps[t]
return x
def markov_chain(T: int, P: np.ndarray, pi0: np.ndarray, seed: int | None = None) -> np.ndarray:
"""Return integer states 0..K-1."""
rng = np.random.default_rng(seed)
K = P.shape[0]
states = np.zeros(T, dtype=np.int64)
states[0] = rng.choice(K, p=pi0)
for t in range(1, T):
states[t] = rng.choice(K, p=P[states[t - 1]])
return states
def make_dummy_real_gen(
    T: int = 8000,
    n_cont: int = 18,
    n_disc: int = 6,
    disc_vocab_sizes: List[int] | None = None,
    seed: int = 7,
) -> Tuple[pd.DataFrame, pd.DataFrame, Dict[str, str], Dict[str, List[str]]]:
    """
    Build paired "real" and "generated" dummy datasets.

    Continuous columns are AR(1) series; the generated side uses slightly
    perturbed (phi, sigma, mu). Discrete columns are Markov chains; the
    generated side uses a perturbed transition matrix.

    Returns:
        real_df, gen_df, feature_types, disc_vocab_map
    where:
        feature_types[col] in {"continuous","discrete"}
        disc_vocab_map[col] = list of token strings (for discrete vars)
    """
    rng = np.random.default_rng(seed)
    cont_cols = [f"c{i:02d}" for i in range(n_cont)]
    disc_cols = [f"d{i:02d}" for i in range(n_disc)]
    # Resolve per-column vocab sizes. Fix: an *empty* user-supplied list used
    # to fall into the extend branch and hit disc_vocab_sizes[-1] -> IndexError;
    # treat empty the same as None and fall back to the defaults.
    if not disc_vocab_sizes:
        base = [3, 4, 5, 3, 6, 4]
        disc_vocab_sizes = (base + [4] * max(0, n_disc - len(base)))[:n_disc]
    elif len(disc_vocab_sizes) < n_disc:
        # If you pass a shorter list, extend it safely with the last entry.
        disc_vocab_sizes = list(disc_vocab_sizes) + [disc_vocab_sizes[-1]] * (n_disc - len(disc_vocab_sizes))
    real = {}
    gen = {}
    # Continuous: AR(1) with slightly different parameters for generated
    for i, col in enumerate(cont_cols):
        phi_real = rng.uniform(0.75, 0.98)
        sigma_real = rng.uniform(0.3, 1.2)
        mu_real = rng.uniform(-1.0, 1.0)
        # Generated is close but not identical
        phi_gen = np.clip(phi_real + rng.normal(0, 0.02), 0.60, 0.995)
        sigma_gen = max(0.05, sigma_real * rng.uniform(0.85, 1.15))
        mu_gen = mu_real + rng.normal(0, 0.15)
        real[col] = ar1_series(T, phi_real, sigma_real, mu_real, seed=seed + 100 + i)
        gen[col] = ar1_series(T, phi_gen, sigma_gen, mu_gen, seed=seed + 200 + i)
    # Discrete: Markov chains with slightly perturbed transition matrices
    disc_vocab_map: Dict[str, List[str]] = {}
    for j, col in enumerate(disc_cols):
        K = disc_vocab_sizes[j]
        disc_vocab_map[col] = [f"s{k}" for k in range(K)]
        # Random but stable transition matrix for real
        A = rng.uniform(0.1, 1.0, size=(K, K))
        P_real = A / A.sum(axis=1, keepdims=True)
        pi0 = rng.uniform(0.1, 1.0, size=K)
        pi0 = pi0 / pi0.sum()
        # Perturb for generated; clip away negatives then renormalize rows so
        # the matrix stays stochastic.
        noise = rng.normal(0, 0.06, size=(K, K))
        P_gen = np.clip(P_real + noise, 1e-6, None)
        P_gen = P_gen / P_gen.sum(axis=1, keepdims=True)
        real_states = markov_chain(T, P_real, pi0, seed=seed + 300 + j)
        gen_states = markov_chain(T, P_gen, pi0, seed=seed + 400 + j)
        # Store as tokens (strings) to emphasize discrete nature
        real[col] = np.array([disc_vocab_map[col][s] for s in real_states], dtype=object)
        gen[col] = np.array([disc_vocab_map[col][s] for s in gen_states], dtype=object)
    real_df = pd.DataFrame(real)
    gen_df = pd.DataFrame(gen)
    feature_types = {**{c: "continuous" for c in cont_cols}, **{d: "discrete" for d in disc_cols}}
    return real_df, gen_df, feature_types, disc_vocab_map
# ----------------------------
# 2) Metric computation
# ----------------------------
def lag1_autocorr(x: np.ndarray) -> float:
    """Lag-1 Pearson autocorrelation (robust to constant arrays).

    Returns NaN when fewer than 3 samples are given, and 0.0 when either
    shifted slice is (numerically) constant, avoiding a 0/0 in corrcoef.
    """
    arr = np.asarray(x, dtype=np.float64)
    if arr.size < 3:
        return np.nan
    lagged, leading = arr[:-1], arr[1:]
    if np.std(lagged) < 1e-12 or np.std(leading) < 1e-12:
        return 0.0
    return float(np.corrcoef(lagged, leading)[0, 1])
def jsd_discrete(real_tokens: np.ndarray, gen_tokens: np.ndarray, vocab: List[str], base: float = 2.0) -> float:
    """Jensen-Shannon divergence between the empirical pmfs over `vocab`.

    A tiny epsilon is added to every count before normalizing so that
    tokens absent from one sample do not produce zero-probability bins.
    """
    eps = 1e-12
    real_counts = pd.Series(real_tokens).value_counts()
    gen_counts = pd.Series(gen_tokens).value_counts()
    p = np.array([real_counts.get(tok, 0) for tok in vocab], dtype=np.float64) + eps
    q = np.array([gen_counts.get(tok, 0) for tok in vocab], dtype=np.float64) + eps
    p = p / p.sum()
    q = q / q.sum()
    # scipy's jensenshannon returns the JS *distance* (sqrt of the divergence),
    # so square it to recover the divergence itself.
    return float(jensenshannon(p, q, base=base) ** 2)
@dataclass
class Metrics:
    """Per-feature benchmark result.

    Metrics that do not apply to a feature type are stored as NaN
    (ks for discrete features, jsd for continuous ones).
    """

    feature: str  # column name in real_df / gen_df
    ftype: str  # "continuous" or "discrete"
    # Fix: the original annotations were `float | np.nan`, which is not a valid
    # type expression (np.nan is a value, not a type); it only avoided a runtime
    # error because of the module's lazy `from __future__ import annotations`.
    # NaN is itself a float, so plain `float` is the correct annotation.
    ks: float  # two-sample Kolmogorov-Smirnov statistic (continuous only)
    jsd: float  # Jensen-Shannon divergence (discrete only)
    lag1_diff: float  # |lag1_autocorr(real) - lag1_autocorr(gen)|
def compute_metrics(
    real_df: pd.DataFrame,
    gen_df: pd.DataFrame,
    feature_types: Dict[str, str],
    disc_vocab_map: Dict[str, List[str]] | None = None,
) -> pd.DataFrame:
    """Compute per-feature KS / JSD / lag-1-difference metrics.

    Continuous columns get a two-sample KS statistic; discrete columns get a
    Jensen-Shannon divergence over their vocabulary. Both types get the
    absolute difference in lag-1 autocorrelation (discrete columns via an
    integer encoding of the vocabulary). Prints the overall averages and
    returns one row per feature.
    """
    records: List[Metrics] = []
    for feat in real_df.columns:
        kind = feature_types[feat]
        real_vals = real_df[feat].to_numpy()
        gen_vals = gen_df[feat].to_numpy()
        if kind == "continuous":
            # KS on raw values
            ks_stat = float(
                ks_2samp(
                    real_vals.astype(np.float64), gen_vals.astype(np.float64), alternative="two-sided"
                ).statistic
            )
            delta = float(abs(lag1_autocorr(real_vals) - lag1_autocorr(gen_vals)))
            records.append(Metrics(feat, kind, ks_stat, np.nan, delta))
        elif kind == "discrete":
            # JSD on the categorical pmf; fall back to the observed token set
            # when no vocabulary was supplied for this column.
            if disc_vocab_map is not None and feat in disc_vocab_map:
                vocab = disc_vocab_map[feat]
            else:
                vocab = sorted(set(real_vals).union(set(gen_vals)))
            jsd_val = jsd_discrete(real_vals, gen_vals, vocab=vocab, base=2.0)
            # Lag-1 diff on an integer encoding captures state "stickiness";
            # set it to np.nan here if you want lag1 for continuous only.
            index_of = {tok: i for i, tok in enumerate(vocab)}
            real_codes = np.array([index_of[t] for t in real_vals], dtype=np.float64)
            gen_codes = np.array([index_of[t] for t in gen_vals], dtype=np.float64)
            delta = float(abs(lag1_autocorr(real_codes) - lag1_autocorr(gen_codes)))
            records.append(Metrics(feat, kind, np.nan, jsd_val, delta))
        else:
            raise ValueError(f"Unknown feature type: {kind}")
    result = pd.DataFrame([vars(rec) for rec in records])
    # Overall averages, kept separate because the metrics are not commensurable.
    overall = {
        "ks_avg_cont": float(result.loc[result.ftype == "continuous", "ks"].mean(skipna=True)),
        "jsd_avg_disc": float(result.loc[result.ftype == "discrete", "jsd"].mean(skipna=True)),
        "lag1_avg_all": float(result["lag1_diff"].mean(skipna=True)),
    }
    print("Overall averages:", overall)
    return result
# ----------------------------
# 3) Fancy but honest visualizations
# ----------------------------
def _ensure_dir(p: str) -> None:
    """Create directory *p* (including parents) if it does not already exist."""
    os.makedirs(p, exist_ok=True)
def plot_heatmap_metrics(metrics_df: pd.DataFrame, outdir: str, title: str = "Per-feature benchmark (normalized)") -> None:
    """Render a per-feature heatmap of the KS / JSD / lag-1 metrics.

    Each metric column is min-max normalized to [0, 1] independently (NaNs
    ignored, e.g. ks for discrete features stays NaN) so the three scales are
    visually comparable; rows are sorted so the feature with the worst
    normalized metric sits on top. Saves heatmap_metrics_normalized.png.
    """
    _ensure_dir(outdir)
    table = metrics_df.copy()
    metric_cols = ["ks", "jsd", "lag1_diff"]
    scaled = table[metric_cols].copy()
    for name in metric_cols:
        vals = scaled[name]
        if vals.notna().sum() == 0:
            continue  # metric absent for every feature; leave the column NaN
        lo = float(vals.min(skipna=True))
        hi = float(vals.max(skipna=True))
        # Degenerate (constant) columns map to 0 instead of dividing by ~0.
        scaled[name] = 0.0 if abs(hi - lo) < 1e-12 else (vals - lo) / (hi - lo)
    # Order rows by their worst normalized metric, worst first.
    table["_severity"] = scaled.max(axis=1, skipna=True)
    table = table.sort_values("_severity", ascending=False).drop(columns=["_severity"])
    scaled = scaled.loc[table.index]
    # Heatmap via plain matplotlib (no seaborn dependency).
    fig, ax = plt.subplots(figsize=(9.5, max(4.0, 0.28 * len(table))))
    image = ax.imshow(scaled.to_numpy(), aspect="auto", interpolation="nearest")
    ax.set_yticks(np.arange(len(table)))
    ax.set_yticklabels(table["feature"].tolist(), fontsize=8)
    ax.set_xticks(np.arange(len(metric_cols)))
    ax.set_xticklabels(metric_cols, fontsize=10)
    ax.set_title(title, fontsize=13, pad=12)
    cbar = fig.colorbar(image, ax=ax, fraction=0.035, pad=0.02)
    cbar.set_label("minmax normalized (per metric)", rotation=90)
    # White minor gridlines delineate individual cells.
    ax.set_xticks(np.arange(-.5, len(metric_cols), 1), minor=True)
    ax.set_yticks(np.arange(-.5, len(table), 1), minor=True)
    ax.grid(which="minor", color="white", linestyle="-", linewidth=0.6)
    ax.tick_params(which="minor", bottom=False, left=False)
    fig.tight_layout()
    fig.savefig(os.path.join(outdir, "heatmap_metrics_normalized.png"), dpi=300)
    plt.close(fig)
def plot_topk_lollipop(metrics_df: pd.DataFrame, metric: str, outdir: str, k: int = 15) -> None:
    """Lollipop chart of the k features scoring worst on `metric`
    ('ks', 'jsd', or 'lag1_diff'); saves topk_<metric>_lollipop.png."""
    _ensure_dir(outdir)
    ranked = metrics_df.dropna(subset=[metric]).sort_values(metric, ascending=False).head(k)
    fig, ax = plt.subplots(figsize=(9.5, 0.45 * len(ranked) + 1.5))
    ypos = np.arange(len(ranked))[::-1]  # worst feature at the top
    scores = ranked[metric].to_numpy()
    ax.hlines(y=ypos, xmin=0, xmax=scores, linewidth=2.0, alpha=0.85)  # stems
    ax.plot(scores, ypos, "o", markersize=7)  # heads
    ax.set_yticks(ypos)
    ax.set_yticklabels(ranked["feature"].tolist(), fontsize=9)
    ax.set_xlabel(metric)
    ax.set_title(f"Top-{k} features by {metric}", fontsize=13, pad=10)
    ax.grid(axis="x", linestyle="--", alpha=0.4)  # subtle x-grid
    fig.tight_layout()
    fig.savefig(os.path.join(outdir, f"topk_{metric}_lollipop.png"), dpi=300)
    plt.close(fig)
def plot_metric_distributions(metrics_df: pd.DataFrame, outdir: str) -> None:
    """Overlaid density histograms of each metric across all features;
    saves metric_distributions.png."""
    _ensure_dir(outdir)
    fig, ax = plt.subplots(figsize=(9.5, 5.2))
    # Semi-transparent overlays keep the comparison compact on one axes.
    for metric_name in ("ks", "jsd", "lag1_diff"):
        sample = metrics_df[metric_name].dropna().to_numpy()
        if sample.size == 0:
            continue  # metric not present for any feature
        ax.hist(sample, bins=20, alpha=0.55, density=True, label=metric_name)
    ax.set_title("Distribution of per-feature benchmark metrics", fontsize=13, pad=10)
    ax.set_xlabel("metric value")
    ax.set_ylabel("density")
    ax.grid(axis="y", linestyle="--", alpha=0.35)
    ax.legend()
    fig.tight_layout()
    fig.savefig(os.path.join(outdir, "metric_distributions.png"), dpi=300)
    plt.close(fig)
def plot_overall_summary(metrics_df: pd.DataFrame, outdir: str) -> None:
    """Bar chart of the three overall averages; each bar has its own meaning
    (KS over continuous, JSD over discrete, lag-1 diff over all).
    Saves overall_summary_bar.png."""
    _ensure_dir(outdir)
    averages = {
        "KS(avg cont)": metrics_df.loc[metrics_df.ftype == "continuous", "ks"].mean(skipna=True),
        "JSD(avg disc)": metrics_df.loc[metrics_df.ftype == "discrete", "jsd"].mean(skipna=True),
        "Lag1 diff(avg)": metrics_df["lag1_diff"].mean(skipna=True),
    }
    fig, ax = plt.subplots(figsize=(7.8, 4.6))
    ax.bar(list(averages.keys()), list(averages.values()))
    ax.set_title("Overall benchmark summary", fontsize=13, pad=10)
    ax.set_ylabel("value")
    ax.grid(axis="y", linestyle="--", alpha=0.35)
    fig.tight_layout()
    fig.savefig(os.path.join(outdir, "overall_summary_bar.png"), dpi=300)
    plt.close(fig)
# ----------------------------
# 4) continuous trends
# ----------------------------
def _smooth_ma(x: np.ndarray, w: int) -> np.ndarray:
"""Simple moving average smoothing; w>=1."""
if w <= 1:
return x.astype(np.float64, copy=True)
s = pd.Series(x.astype(np.float64))
return s.rolling(w, center=True, min_periods=max(2, w // 10)).mean().to_numpy()
def _rolling_quantile(x: np.ndarray, w: int, q: float) -> np.ndarray:
"""Rolling quantile for a local envelope."""
if w <= 1:
return x.astype(np.float64, copy=True)
s = pd.Series(x.astype(np.float64))
return s.rolling(w, center=True, min_periods=max(10, w // 5)).quantile(q).to_numpy()
def plot_trend_small_multiples(
    real_df: pd.DataFrame,
    gen_df: pd.DataFrame,
    metrics_df: pd.DataFrame,
    feature_types: Dict[str, str],
    outdir: str,
    k: int = 12,
    rank_by: str = "lag1_diff",  # or "ks"
    smooth_w: int = 200,
    env_w: int = 600,
    downsample: int = 1,
) -> None:
    """
    Small multiples of continuous-feature trends.
    - Picks top-k continuous features by rank_by (ks or lag1_diff).
    - Plots: faint raw + bold smoothed + rolling 10-90% envelope for real and gen.
    Saves trend_small_multiples_<rank_by>.png.
    """
    _ensure_dir(outdir)
    # pick top-k continuous features by the requested metric
    dfc = metrics_df[metrics_df["ftype"] == "continuous"].dropna(subset=[rank_by]).copy()
    dfc = dfc.sort_values(rank_by, ascending=False).head(k)
    feats = dfc["feature"].tolist()
    if not feats:
        print("[trend] No continuous features found to plot.")
        return
    n = len(feats)
    ncols = 3
    nrows = int(np.ceil(n / ncols))
    # BUG FIX: the original did `fig = plt.subplots(...)[0]`, which creates a
    # figure that already contains one blank axes. `fig.axes[0]` (used below
    # for the shared legend) was therefore that blank axes, so the legend
    # rendered with no entries. An empty figure makes fig.axes[0] the first
    # gridspec panel, which is the one carrying the labels.
    fig = plt.figure(figsize=(12, 3.2 * nrows))
    gs = fig.add_gridspec(nrows, ncols, hspace=0.35, wspace=0.22)
    for idx, feat in enumerate(feats):
        r = real_df[feat].to_numpy(dtype=np.float64)
        g = gen_df[feat].to_numpy(dtype=np.float64)
        if downsample > 1:
            r = r[::downsample]
            g = g[::downsample]
        t = np.arange(len(r), dtype=np.int64)
        # smoothed trend
        r_tr = _smooth_ma(r, smooth_w)
        g_tr = _smooth_ma(g, smooth_w)
        # local 10%/90% rolling envelope
        r_lo = _rolling_quantile(r, env_w, 0.10)
        r_hi = _rolling_quantile(r, env_w, 0.90)
        g_lo = _rolling_quantile(g, env_w, 0.10)
        g_hi = _rolling_quantile(g, env_w, 0.90)
        ax = fig.add_subplot(gs[idx // ncols, idx % ncols])
        # faint raw "watermark"; labels are attached only on the first panel
        # so the shared figure-level legend gets exactly one entry per series
        ax.plot(t, r, linewidth=0.6, alpha=0.18, label="real (raw)" if idx == 0 else None)
        ax.plot(t, g, linewidth=0.6, alpha=0.18, label="gen (raw)" if idx == 0 else None)
        # envelopes
        # NOTE(review): the "1090%" label text looks like a mangled "10-90%" —
        # confirm the intended wording before changing user-visible strings.
        ax.fill_between(t, r_lo, r_hi, alpha=0.18, label="real (1090%)" if idx == 0 else None)
        ax.fill_between(t, g_lo, g_hi, alpha=0.18, label="gen (1090%)" if idx == 0 else None)
        # bold trends
        ax.plot(t, r_tr, linewidth=2.0, alpha=0.9, label="real (trend)" if idx == 0 else None)
        ax.plot(t, g_tr, linewidth=2.0, alpha=0.9, label="gen (trend)" if idx == 0 else None)
        # title with metric value
        mval = float(dfc.loc[dfc["feature"] == feat, rank_by].iloc[0])
        ax.set_title(f"{feat} | {rank_by}={mval:.3f}", fontsize=10)
        # cosmetics
        ax.grid(axis="y", linestyle="--", alpha=0.25)
        ax.tick_params(labelsize=8)
    # shared legend built from the first panel's labeled artists
    handles, labels = fig.axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper center", ncol=3, frameon=False, bbox_to_anchor=(0.5, 0.99))
    fig.suptitle("Trend comparison (continuous features): raw watermark + smoothed trend + local envelope",
                 fontsize=14, y=1.02)
    fig.tight_layout()
    fig.savefig(os.path.join(outdir, f"trend_small_multiples_{rank_by}.png"), dpi=300, bbox_inches="tight")
    plt.close(fig)
# ----------------------------
# 5) discrete trends
# ----------------------------
def plot_discrete_state_trends(
    real_df: pd.DataFrame,
    gen_df: pd.DataFrame,
    metrics_df: pd.DataFrame,
    feature_types: Dict[str, str],
    disc_vocab_map: Dict[str, List[str]] | None,
    outdir: str,
    k: int = 6,
    rank_by: str = "jsd",
    win: int = 400,
    step: int = 40,
) -> None:
    """Plot rolling state occupancy for the top-k discrete features by `rank_by`.

    For each selected feature, p_t(v) is the fraction of samples equal to
    state v inside a centered window of length `win`, evaluated every `step`
    positions; real and generated occupancies are stacked on shared axes.
    Saves one disc_state_trend_<feature>.png per feature.
    """
    _ensure_dir(outdir)
    ranked = metrics_df[metrics_df["ftype"] == "discrete"].dropna(subset=[rank_by]).copy()
    ranked = ranked.sort_values(rank_by, ascending=False).head(k)
    selected = ranked["feature"].tolist()
    if not selected:
        print("[trend] No discrete features found to plot.")
        return
    half = win // 2
    for feat in selected:
        real_seq = real_df[feat].to_numpy()
        gen_seq = gen_df[feat].to_numpy()
        vocab = disc_vocab_map.get(feat) if disc_vocab_map else sorted(list(set(real_seq).union(set(gen_seq))))
        # Window centers across the timeline.
        centers = np.arange(half, len(real_seq) - half, step, dtype=np.int64)
        # Occupancy matrices, shape [len(vocab), len(centers)].
        occ_real = np.zeros((len(vocab), len(centers)), dtype=np.float64)
        occ_gen = np.zeros((len(vocab), len(centers)), dtype=np.float64)
        for ci, center in enumerate(centers):
            real_win = real_seq[center - half:center + half]
            gen_win = gen_seq[center - half:center + half]
            # per-window state counts
            real_counts = pd.Series(real_win).value_counts()
            gen_counts = pd.Series(gen_win).value_counts()
            for vi, tok in enumerate(vocab):
                occ_real[vi, ci] = real_counts.get(tok, 0) / max(1, len(real_win))
                occ_gen[vi, ci] = gen_counts.get(tok, 0) / max(1, len(gen_win))
        # Two stacked areas on one axes: real vs generated, same vocab order.
        fig, ax = plt.subplots(figsize=(11.5, 4.8))
        ax.stackplot(centers, occ_real, alpha=0.35, labels=[f"{v} (real)" for v in vocab])
        ax.stackplot(centers, occ_gen, alpha=0.35, labels=[f"{v} (gen)" for v in vocab])
        score = float(ranked.loc[ranked["feature"] == feat, rank_by].iloc[0])
        ax.set_title(f"Discrete trend via rolling state occupancy — {feat} | {rank_by}={score:.3f}", fontsize=13, pad=10)
        ax.set_xlabel("time index (window centers)")
        ax.set_ylabel("occupancy probability")
        ax.grid(axis="y", linestyle="--", alpha=0.25)
        # Legend sits outside the axes to keep the plot area clean.
        ax.legend(loc="center left", bbox_to_anchor=(1.01, 0.5), frameon=False, fontsize=8)
        fig.tight_layout()
        fig.savefig(os.path.join(outdir, f"disc_state_trend_{feat}.png"), dpi=300, bbox_inches="tight")
        plt.close(fig)
# ----------------------------
# 6) Main entry
# ----------------------------
def main() -> None:
    """Generate dummy data, compute the benchmark metrics, and save all
    figures plus a per-feature CSV under ./benchmark_figs/."""
    fig_dir = "benchmark_figs"
    _ensure_dir(fig_dir)
    # Swap these for your own real/generated samples.
    real_df, gen_df, feature_types, disc_vocab_map = make_dummy_real_gen(
        T=10000, n_cont=20, n_disc=8, seed=42
    )
    metrics_df = compute_metrics(real_df, gen_df, feature_types, disc_vocab_map)
    metrics_df.to_csv(os.path.join(fig_dir, "metrics_per_feature.csv"), index=False)
    # Static summaries
    plot_heatmap_metrics(metrics_df, fig_dir)
    plot_metric_distributions(metrics_df, fig_dir)
    plot_overall_summary(metrics_df, fig_dir)
    # Top-k charts: choose what you emphasize
    plot_topk_lollipop(metrics_df, metric="ks", outdir=fig_dir, k=15)
    plot_topk_lollipop(metrics_df, metric="jsd", outdir=fig_dir, k=15)
    plot_topk_lollipop(metrics_df, metric="lag1_diff", outdir=fig_dir, k=15)
    # Trend figures: continuous small multiples ranked two ways, then discrete
    plot_trend_small_multiples(
        real_df, gen_df, metrics_df, feature_types, fig_dir,
        k=12, rank_by="lag1_diff", smooth_w=250, env_w=800, downsample=1
    )
    plot_trend_small_multiples(
        real_df, gen_df, metrics_df, feature_types, fig_dir,
        k=12, rank_by="ks", smooth_w=250, env_w=800, downsample=1
    )
    plot_discrete_state_trends(
        real_df, gen_df, metrics_df, feature_types, disc_vocab_map, fig_dir,
        k=6, rank_by="jsd", win=500, step=50
    )
    print(f"Saved figures + CSV under: ./{fig_dir}/")


if __name__ == "__main__":
    main()