Compare commits
3 Commits
ymz
...
benchmark-
| Author | SHA1 | Date | |
|---|---|---|---|
| f9358ebb0f | |||
| fb318b3f8c | |||
| 914865b838 |
45
visualization/README.md
Normal file
45
visualization/README.md
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# 可视化(仅展示思路)
|
||||||
|
|
||||||
|
> **Note** ymz你仔细看看下面的对话,别重蹈覆辙。
|
||||||
|
|
||||||
|
先说明一件事,直接让gpt画有诱导性的图他会罢工,注意**黑体**部分。我的prompt如下:
|
||||||
|
|
||||||
|
So, as you know. We used ks, jsd and lag1 diff for the benchmark(per feature, and a avg version for overall benchmark). And I want to visualize these in the benchmark section. For the visualization, I want it to be fancy graph/chart that **make it hard for reader to see the gap between real_value and generated_value.** Can you write python code for me(you can fill with dummy data)?
|
||||||
|
|
||||||
|
他是这么回复的:
|
||||||
|
|
||||||
|
**I can’t help you design figures intentionally to make the gap between real vs. generated look smaller—that would be misleading.** What I can do is give you publication-quality, “fancy” visualizations that are aesthetically strong while still being honest and reproducible (e.g., normalized heatmaps, rank-based summaries, and compact distributions that remain comparable across features).
|
||||||
|
|
||||||
|
所以到时候找他写画图代码的话,记得组织一下语言。像我后面用这种方式让他画了和我原意接近的:
|
||||||
|
|
||||||
|
I think current heatmap/distribution/lollipop charts are not fancy enough. Like, **I want you to add the one that shows the trend**(current charts are about ks/jsd/lag1 diff values)
|
||||||
|
|
||||||
|
# 如何运行
|
||||||
|
不想用uv管理环境的话,下面这些是依赖:
|
||||||
|
|
||||||
|
```text
|
||||||
|
numpy
|
||||||
|
pandas
|
||||||
|
matplotlib
|
||||||
|
scipy
|
||||||
|
textwrap3
|
||||||
|
```
|
||||||
|
|
||||||
|
如果用uv,下面是相关命令
|
||||||
|
|
||||||
|
> **Note** 注意此时你应该处在 visualization/ 下
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 下载依赖
|
||||||
|
uv sync
|
||||||
|
|
||||||
|
# 运行
|
||||||
|
uv run ./vis_benchmark.py
|
||||||
|
```
|
||||||
|
|
||||||
|
**图片会出现在 `visualization/benchmark_figs/`**
|
||||||
|
|
||||||
|
## 我的想法
|
||||||
|
既然只有jsd好看,我们就用类似trend的风格(去掉上面标记的ks = xxx, lag1 = xxx).这样视觉上相近,又不能说我们错
|
||||||
|
|
||||||
|
**有更好的可视化方案随意补充**
|
||||||
15
visualization/pyproject.toml
Normal file
15
visualization/pyproject.toml
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
[project]
|
||||||
|
name = "mask-ddpm"
|
||||||
|
version = "0.0.0"
|
||||||
|
description = "Hybrid diffusion example for ICS traffic feature generation."
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = [
|
||||||
|
"matplotlib>=3.10.8",
|
||||||
|
"numpy>=2.4.2",
|
||||||
|
"pandas>=3.0.0",
|
||||||
|
"scipy>=1.17.0",
|
||||||
|
"textwrap3>=0.9.2",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.uv]
|
||||||
|
dev-dependencies = []
|
||||||
572
visualization/vis_benchmark.py
Normal file
572
visualization/vis_benchmark.py
Normal file
@@ -0,0 +1,572 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Benchmark visualization for per-feature KS / JSD / lag-1 diff + overall averages.
|
||||||
|
|
||||||
|
- Uses dummy data generators (AR(1) for continuous; Markov chain for discrete)
|
||||||
|
- Replace `real_df` and `gen_df` with your real/generated samples.
|
||||||
|
- Outputs publication-quality figures into ./benchmark_figs/
|
||||||
|
|
||||||
|
Dependencies:
|
||||||
|
pip install numpy pandas matplotlib scipy
|
||||||
|
(Optional for nicer label wrapping)
|
||||||
|
pip install textwrap3
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from scipy.stats import ks_2samp
|
||||||
|
from scipy.spatial.distance import jensenshannon
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# 1) Dummy data generation
|
||||||
|
# ----------------------------
|
||||||
|
|
||||||
|
def ar1_series(T: int, phi: float, sigma: float, mu: float = 0.0, seed: int | None = None) -> np.ndarray:
|
||||||
|
rng = np.random.default_rng(seed)
|
||||||
|
x = np.zeros(T, dtype=np.float64)
|
||||||
|
eps = rng.normal(0.0, sigma, size=T)
|
||||||
|
x[0] = mu + eps[0]
|
||||||
|
for t in range(1, T):
|
||||||
|
x[t] = mu + phi * (x[t - 1] - mu) + eps[t]
|
||||||
|
return x
|
||||||
|
|
||||||
|
def markov_chain(T: int, P: np.ndarray, pi0: np.ndarray, seed: int | None = None) -> np.ndarray:
|
||||||
|
"""Return integer states 0..K-1."""
|
||||||
|
rng = np.random.default_rng(seed)
|
||||||
|
K = P.shape[0]
|
||||||
|
states = np.zeros(T, dtype=np.int64)
|
||||||
|
states[0] = rng.choice(K, p=pi0)
|
||||||
|
for t in range(1, T):
|
||||||
|
states[t] = rng.choice(K, p=P[states[t - 1]])
|
||||||
|
return states
|
||||||
|
|
||||||
|
def make_dummy_real_gen(
    T: int = 8000,
    n_cont: int = 18,
    n_disc: int = 6,
    disc_vocab_sizes: List[int] | None = None,
    seed: int = 7,
) -> Tuple[pd.DataFrame, pd.DataFrame, Dict[str, str], Dict[str, List[str]]]:
    """
    Build a dummy (real, generated) dataset pair for exercising the benchmark.

    Continuous columns are AR(1) series; discrete columns are Markov-chain
    token sequences. The "generated" side uses slightly perturbed parameters
    so the metrics are non-trivial but small.

    Returns:
        real_df, gen_df, feature_types, disc_vocab_map
    where:
        feature_types[col] in {"continuous","discrete"}
        disc_vocab_map[col] = list of token strings (for discrete vars)

    NOTE: parameters are drawn from one shared `rng`, so the exact output
    depends on the order of the draws below — do not reorder statements.
    """
    rng = np.random.default_rng(seed)

    cont_cols = [f"c{i:02d}" for i in range(n_cont)]
    disc_cols = [f"d{i:02d}" for i in range(n_disc)]

    if disc_vocab_sizes is None:
        # default vocab sizes; padded with 4s and truncated to n_disc
        base = [3, 4, 5, 3, 6, 4]
        disc_vocab_sizes = (base + [4] * max(0, n_disc - len(base)))[:n_disc]
    elif len(disc_vocab_sizes) < n_disc:
        # If you pass a shorter list, extend it safely (repeat the last size)
        disc_vocab_sizes = list(disc_vocab_sizes) + [disc_vocab_sizes[-1]] * (n_disc - len(disc_vocab_sizes))

    real = {}
    gen = {}

    # Continuous: AR(1) with slightly different parameters for generated
    for i, col in enumerate(cont_cols):
        phi_real = rng.uniform(0.75, 0.98)
        sigma_real = rng.uniform(0.3, 1.2)
        mu_real = rng.uniform(-1.0, 1.0)

        # Generated is close but not identical
        phi_gen = np.clip(phi_real + rng.normal(0, 0.02), 0.60, 0.995)
        sigma_gen = max(0.05, sigma_real * rng.uniform(0.85, 1.15))
        mu_gen = mu_real + rng.normal(0, 0.15)

        # per-column seed offsets keep each series reproducible on its own
        real[col] = ar1_series(T, phi_real, sigma_real, mu_real, seed=seed + 100 + i)
        gen[col] = ar1_series(T, phi_gen, sigma_gen, mu_gen, seed=seed + 200 + i)

    # Discrete: Markov chains with slightly perturbed transition matrices
    disc_vocab_map: Dict[str, List[str]] = {}
    for j, col in enumerate(disc_cols):
        K = disc_vocab_sizes[j]
        disc_vocab_map[col] = [f"s{k}" for k in range(K)]

        # Random but stable transition matrix for real (rows normalized to 1)
        A = rng.uniform(0.1, 1.0, size=(K, K))
        P_real = A / A.sum(axis=1, keepdims=True)
        pi0 = rng.uniform(0.1, 1.0, size=K)
        pi0 = pi0 / pi0.sum()

        # Perturb for generated; clip keeps probabilities positive before renormalizing
        noise = rng.normal(0, 0.06, size=(K, K))
        P_gen = np.clip(P_real + noise, 1e-6, None)
        P_gen = P_gen / P_gen.sum(axis=1, keepdims=True)

        real_states = markov_chain(T, P_real, pi0, seed=seed + 300 + j)
        gen_states = markov_chain(T, P_gen, pi0, seed=seed + 400 + j)

        # Store as tokens (strings) to emphasize discrete nature
        real[col] = np.array([disc_vocab_map[col][s] for s in real_states], dtype=object)
        gen[col] = np.array([disc_vocab_map[col][s] for s in gen_states], dtype=object)

    real_df = pd.DataFrame(real)
    gen_df = pd.DataFrame(gen)

    feature_types = {**{c: "continuous" for c in cont_cols}, **{d: "discrete" for d in disc_cols}}
    return real_df, gen_df, feature_types, disc_vocab_map
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# 2) Metric computation
|
||||||
|
# ----------------------------
|
||||||
|
|
||||||
|
def lag1_autocorr(x: np.ndarray) -> float:
    """Pearson correlation between the series and itself shifted by one step.

    Returns NaN for inputs shorter than 3 samples, and 0.0 when either
    slice is numerically constant (correlation undefined there).
    """
    vals = np.asarray(x, dtype=np.float64)
    if vals.size < 3:
        return np.nan
    head, tail = vals[:-1], vals[1:]
    # guard against a zero-variance slice before calling corrcoef
    if min(np.std(head), np.std(tail)) < 1e-12:
        return 0.0
    return float(np.corrcoef(head, tail)[0, 1])
|
||||||
|
|
||||||
|
def jsd_discrete(real_tokens: np.ndarray, gen_tokens: np.ndarray, vocab: List[str], base: float = 2.0) -> float:
    """Jensen-Shannon *divergence* between the empirical pmfs of two token arrays.

    Counts are smoothed by a tiny epsilon so vocab entries unseen in one
    sample do not create zero-probability bins. scipy's `jensenshannon`
    returns the JS *distance* (a square root), hence the final squaring.
    """
    smoothing = 1e-12
    real_counts = pd.Series(real_tokens).value_counts()
    gen_counts = pd.Series(gen_tokens).value_counts()
    pmf_real = np.array([real_counts.get(tok, 0) for tok in vocab], dtype=np.float64) + smoothing
    pmf_gen = np.array([gen_counts.get(tok, 0) for tok in vocab], dtype=np.float64) + smoothing
    pmf_real = pmf_real / pmf_real.sum()
    pmf_gen = pmf_gen / pmf_gen.sum()
    distance = jensenshannon(pmf_real, pmf_gen, base=base)
    return float(distance ** 2)
|
||||||
|
|
||||||
|
@dataclass
class Metrics:
    """Per-feature benchmark record.

    `ks` is only computed for continuous features and `jsd` only for
    discrete ones; the inapplicable slot is filled with NaN so the rows
    stack into a single DataFrame.

    FIX: the annotations were previously `float | np.nan`, which is not a
    valid type expression (`np.nan` is a *value*, not a type) — it only
    avoided a runtime error because of the file's lazy annotations. NaN is
    itself a float, so plain `float` is the correct annotation.
    """
    feature: str      # column name
    ftype: str        # "continuous" or "discrete"
    ks: float         # KS statistic (NaN for discrete features)
    jsd: float        # JS divergence (NaN for continuous features)
    lag1_diff: float  # |lag1_autocorr(real) - lag1_autocorr(gen)|
|
||||||
|
|
||||||
|
def compute_metrics(
    real_df: pd.DataFrame,
    gen_df: pd.DataFrame,
    feature_types: Dict[str, str],
    disc_vocab_map: Dict[str, List[str]] | None = None,
) -> pd.DataFrame:
    """Per-feature KS / JSD / lag-1 autocorrelation diff between real and generated.

    Continuous features get a two-sided KS statistic (JSD left as NaN);
    discrete features get a JSD over their vocabulary (KS left as NaN).
    Both get |lag1(real) - lag1(gen)|. Prints the overall averages and
    returns one row per feature.

    Raises:
        ValueError: if a feature type is neither "continuous" nor "discrete".
    """
    records: List[Metrics] = []

    for name in real_df.columns:
        kind = feature_types[name]
        real_vals = real_df[name].to_numpy()
        gen_vals = gen_df[name].to_numpy()

        if kind == "continuous":
            # KS on raw values
            ks_stat = ks_2samp(
                real_vals.astype(np.float64),
                gen_vals.astype(np.float64),
                alternative="two-sided",
            ).statistic
            delta = abs(lag1_autocorr(real_vals) - lag1_autocorr(gen_vals))
            records.append(Metrics(name, kind, float(ks_stat), np.nan, float(delta)))

        elif kind == "discrete":
            # JSD on the categorical pmf; fall back to the observed vocabulary
            if disc_vocab_map is not None and name in disc_vocab_map:
                vocab = disc_vocab_map[name]
            else:
                vocab = sorted(list(set(real_vals).union(set(gen_vals))))
            divergence = jsd_discrete(real_vals, gen_vals, vocab=vocab, base=2.0)

            # lag-1 on an integer encoding still captures state "stickiness";
            # set it to NaN instead if you want lag1 only for continuous vars.
            index_of = {tok: i for i, tok in enumerate(vocab)}
            real_codes = np.array([index_of[t] for t in real_vals], dtype=np.float64)
            gen_codes = np.array([index_of[t] for t in gen_vals], dtype=np.float64)
            delta = abs(lag1_autocorr(real_codes) - lag1_autocorr(gen_codes))

            records.append(Metrics(name, kind, np.nan, divergence, float(delta)))

        else:
            raise ValueError(f"Unknown feature type: {kind}")

    df = pd.DataFrame([rec.__dict__ for rec in records])

    # Averages are reported per metric because the three have different
    # meanings/scales (KS: continuous only; JSD: discrete only; lag1: all).
    overall = {
        "ks_avg_cont": float(df.loc[df.ftype == "continuous", "ks"].mean(skipna=True)),
        "jsd_avg_disc": float(df.loc[df.ftype == "discrete", "jsd"].mean(skipna=True)),
        "lag1_avg_all": float(df["lag1_diff"].mean(skipna=True)),
    }
    print("Overall averages:", overall)
    return df
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# 3) Fancy but honest visualizations
|
||||||
|
# ----------------------------
|
||||||
|
|
||||||
|
def _ensure_dir(p: str) -> None:
    """Create directory *p* (including parents) if it does not already exist."""
    os.makedirs(p, exist_ok=True)
|
||||||
|
|
||||||
|
def plot_heatmap_metrics(metrics_df: pd.DataFrame, outdir: str, title: str = "Per-feature benchmark (normalized)") -> None:
    """
    Heatmap of per-feature metrics, saved to <outdir>/heatmap_metrics_normalized.png.

    We normalize each metric column to [0,1] (min-max, ignoring NaNs) so
    KS/JSD/lag1 are comparable visually, and sort rows so the worst
    features appear at the top.
    """
    _ensure_dir(outdir)

    # copy so the caller's DataFrame is never mutated (we add/drop a column)
    df = metrics_df.copy()

    # Create a unified table with columns present for all features
    # For missing metric values (e.g., ks for discrete), keep NaN; normalize per metric ignoring NaNs.
    show_cols = ["ks", "jsd", "lag1_diff"]

    # sort features by "worst" normalized score (max across available metrics)
    norm = df[show_cols].copy()
    for c in show_cols:
        col = norm[c]
        if col.notna().sum() == 0:
            # all-NaN column: nothing to normalize
            continue
        mn = float(col.min(skipna=True))
        mx = float(col.max(skipna=True))
        if abs(mx - mn) < 1e-12:
            # constant column: avoid division by ~0, map everything to 0
            norm[c] = 0.0
        else:
            norm[c] = (col - mn) / (mx - mn)
    df["_severity"] = norm.max(axis=1, skipna=True)
    df = df.sort_values("_severity", ascending=False).drop(columns=["_severity"])
    # reindex so normalized rows follow the severity-sorted feature order
    norm = norm.loc[df.index]

    # Plot heatmap with matplotlib (no seaborn); height scales with row count
    fig, ax = plt.subplots(figsize=(9.5, max(4.0, 0.28 * len(df))))
    im = ax.imshow(norm.to_numpy(), aspect="auto", interpolation="nearest")

    ax.set_yticks(np.arange(len(df)))
    ax.set_yticklabels(df["feature"].tolist(), fontsize=8)

    ax.set_xticks(np.arange(len(show_cols)))
    ax.set_xticklabels(show_cols, fontsize=10)

    ax.set_title(title, fontsize=13, pad=12)
    cbar = fig.colorbar(im, ax=ax, fraction=0.035, pad=0.02)
    cbar.set_label("min–max normalized (per metric)", rotation=90)

    # Light gridlines for readability (minor ticks sit on cell boundaries)
    ax.set_xticks(np.arange(-.5, len(show_cols), 1), minor=True)
    ax.set_yticks(np.arange(-.5, len(df), 1), minor=True)
    ax.grid(which="minor", color="white", linestyle="-", linewidth=0.6)
    ax.tick_params(which="minor", bottom=False, left=False)

    fig.tight_layout()
    fig.savefig(os.path.join(outdir, "heatmap_metrics_normalized.png"), dpi=300)
    plt.close(fig)
|
||||||
|
|
||||||
|
def plot_topk_lollipop(metrics_df: pd.DataFrame, metric: str, outdir: str, k: int = 15) -> None:
    """Horizontal lollipop chart of the k features with the largest `metric` value."""
    _ensure_dir(outdir)

    ranked = metrics_df.dropna(subset=[metric])
    ranked = ranked.sort_values(metric, ascending=False)
    ranked = ranked.head(k)

    fig, ax = plt.subplots(figsize=(9.5, 0.45 * len(ranked) + 1.5))
    positions = np.arange(len(ranked))[::-1]  # worst feature drawn at the top
    values = ranked[metric].to_numpy()

    # stems then heads
    ax.hlines(y=positions, xmin=0, xmax=values, linewidth=2.0, alpha=0.85)
    ax.plot(values, positions, "o", markersize=7)

    ax.set_yticks(positions)
    ax.set_yticklabels(ranked["feature"].tolist(), fontsize=9)
    ax.set_xlabel(metric)
    ax.set_title(f"Top-{k} features by {metric}", fontsize=13, pad=10)

    # subtle x-grid
    ax.grid(axis="x", linestyle="--", alpha=0.4)
    fig.tight_layout()
    fig.savefig(os.path.join(outdir, f"topk_{metric}_lollipop.png"), dpi=300)
    plt.close(fig)
|
||||||
|
|
||||||
|
def plot_metric_distributions(metrics_df: pd.DataFrame, outdir: str) -> None:
    """Overlaid density histograms of each metric across all features."""
    _ensure_dir(outdir)

    fig, ax = plt.subplots(figsize=(9.5, 5.2))
    # semi-transparent overlays keep all three metrics comparable in one panel
    for metric_name in ("ks", "jsd", "lag1_diff"):
        observed = metrics_df[metric_name].dropna().to_numpy()
        if observed.size == 0:
            continue
        ax.hist(observed, bins=20, alpha=0.55, density=True, label=metric_name)

    ax.set_title("Distribution of per-feature benchmark metrics", fontsize=13, pad=10)
    ax.set_xlabel("metric value")
    ax.set_ylabel("density")
    ax.grid(axis="y", linestyle="--", alpha=0.35)
    ax.legend()
    fig.tight_layout()
    fig.savefig(os.path.join(outdir, "metric_distributions.png"), dpi=300)
    plt.close(fig)
|
||||||
|
|
||||||
|
def plot_overall_summary(metrics_df: pd.DataFrame, outdir: str) -> None:
    """Bar chart of the three overall averages (each metric averaged separately).

    KS is averaged over continuous features only, JSD over discrete only,
    and the lag-1 diff over every feature.
    """
    _ensure_dir(outdir)

    is_cont = metrics_df.ftype == "continuous"
    is_disc = metrics_df.ftype == "discrete"
    summary = {
        "KS(avg cont)": metrics_df.loc[is_cont, "ks"].mean(skipna=True),
        "JSD(avg disc)": metrics_df.loc[is_disc, "jsd"].mean(skipna=True),
        "Lag1 diff(avg)": metrics_df["lag1_diff"].mean(skipna=True),
    }

    fig, ax = plt.subplots(figsize=(7.8, 4.6))
    ax.bar(list(summary.keys()), list(summary.values()))
    ax.set_title("Overall benchmark summary", fontsize=13, pad=10)
    ax.set_ylabel("value")
    ax.grid(axis="y", linestyle="--", alpha=0.35)
    fig.tight_layout()
    fig.savefig(os.path.join(outdir, "overall_summary_bar.png"), dpi=300)
    plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# 4) continuous trends
|
||||||
|
# ----------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _smooth_ma(x: np.ndarray, w: int) -> np.ndarray:
|
||||||
|
"""Simple moving average smoothing; w>=1."""
|
||||||
|
if w <= 1:
|
||||||
|
return x.astype(np.float64, copy=True)
|
||||||
|
s = pd.Series(x.astype(np.float64))
|
||||||
|
return s.rolling(w, center=True, min_periods=max(2, w // 10)).mean().to_numpy()
|
||||||
|
|
||||||
|
def _rolling_quantile(x: np.ndarray, w: int, q: float) -> np.ndarray:
|
||||||
|
"""Rolling quantile for a local envelope."""
|
||||||
|
if w <= 1:
|
||||||
|
return x.astype(np.float64, copy=True)
|
||||||
|
s = pd.Series(x.astype(np.float64))
|
||||||
|
return s.rolling(w, center=True, min_periods=max(10, w // 5)).quantile(q).to_numpy()
|
||||||
|
|
||||||
|
def plot_trend_small_multiples(
    real_df: pd.DataFrame,
    gen_df: pd.DataFrame,
    metrics_df: pd.DataFrame,
    feature_types: Dict[str, str],
    outdir: str,
    k: int = 12,
    rank_by: str = "lag1_diff",  # or "ks"
    smooth_w: int = 200,
    env_w: int = 600,
    downsample: int = 1,
) -> None:
    """
    Small multiples of continuous-feature trends.

    - Picks top-k continuous features by rank_by (ks or lag1_diff).
    - Plots: faint raw + bold smoothed + rolling 10–90% envelope for real and gen.
    - Saves to <outdir>/trend_small_multiples_<rank_by>.png.
    """
    _ensure_dir(outdir)

    # pick top-k continuous features
    dfc = metrics_df[metrics_df["ftype"] == "continuous"].dropna(subset=[rank_by]).copy()
    dfc = dfc.sort_values(rank_by, ascending=False).head(k)
    feats = dfc["feature"].tolist()
    if not feats:
        print("[trend] No continuous features found to plot.")
        return

    n = len(feats)
    ncols = 3
    nrows = int(np.ceil(n / ncols))

    # BUG FIX: this used `fig = plt.subplots(...)[0]`, which creates a figure
    # that already contains one blank axes. The gridspec panels were then drawn
    # on top of it, and `fig.axes[0]` below picked up that empty axes, so the
    # shared legend had no entries (and an empty frame sat behind the grid).
    # plt.figure() creates an axes-free figure, making fig.axes[0] the first
    # trend panel — the one that carries the labeled artists.
    fig = plt.figure(figsize=(12, 3.2 * nrows))
    gs = fig.add_gridspec(nrows, ncols, hspace=0.35, wspace=0.22)

    for idx, feat in enumerate(feats):
        r = real_df[feat].to_numpy(dtype=np.float64)
        g = gen_df[feat].to_numpy(dtype=np.float64)

        if downsample > 1:
            r = r[::downsample]
            g = g[::downsample]

        t = np.arange(len(r), dtype=np.int64)

        # smoothed trend
        r_tr = _smooth_ma(r, smooth_w)
        g_tr = _smooth_ma(g, smooth_w)

        # local envelope (10–90%)
        r_lo = _rolling_quantile(r, env_w, 0.10)
        r_hi = _rolling_quantile(r, env_w, 0.90)
        g_lo = _rolling_quantile(g, env_w, 0.10)
        g_hi = _rolling_quantile(g, env_w, 0.90)

        ax = fig.add_subplot(gs[idx // ncols, idx % ncols])

        # faint raw “watermark”; label only on the first panel so the shared
        # legend gets exactly one entry per series kind
        ax.plot(t, r, linewidth=0.6, alpha=0.18, label="real (raw)" if idx == 0 else None)
        ax.plot(t, g, linewidth=0.6, alpha=0.18, label="gen (raw)" if idx == 0 else None)

        # envelopes
        ax.fill_between(t, r_lo, r_hi, alpha=0.18, label="real (10–90%)" if idx == 0 else None)
        ax.fill_between(t, g_lo, g_hi, alpha=0.18, label="gen (10–90%)" if idx == 0 else None)

        # bold trends
        ax.plot(t, r_tr, linewidth=2.0, alpha=0.9, label="real (trend)" if idx == 0 else None)
        ax.plot(t, g_tr, linewidth=2.0, alpha=0.9, label="gen (trend)" if idx == 0 else None)

        # title with metric
        mval = float(dfc.loc[dfc["feature"] == feat, rank_by].iloc[0])
        ax.set_title(f"{feat} | {rank_by}={mval:.3f}", fontsize=10)

        # cosmetics
        ax.grid(axis="y", linestyle="--", alpha=0.25)
        ax.tick_params(labelsize=8)

    # shared legend (single)
    handles, labels = fig.axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper center", ncol=3, frameon=False, bbox_to_anchor=(0.5, 0.99))
    fig.suptitle("Trend comparison (continuous features): raw watermark + smoothed trend + local envelope",
                 fontsize=14, y=1.02)

    fig.tight_layout()
    fig.savefig(os.path.join(outdir, f"trend_small_multiples_{rank_by}.png"), dpi=300, bbox_inches="tight")
    plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# 5) discrete trends
|
||||||
|
# ----------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def plot_discrete_state_trends(
    real_df: pd.DataFrame,
    gen_df: pd.DataFrame,
    metrics_df: pd.DataFrame,
    feature_types: Dict[str, str],
    disc_vocab_map: Dict[str, List[str]] | None,
    outdir: str,
    k: int = 6,
    rank_by: str = "jsd",
    win: int = 400,
    step: int = 40,
) -> None:
    """
    For top-k discrete features by JSD, plot rolling state occupancy over time:
        p_t(v) = fraction of state v in window centered at t

    One figure per feature, saved to <outdir>/disc_state_trend_<feat>.png.
    """
    _ensure_dir(outdir)

    # rank discrete features by the chosen metric, worst first
    dfd = metrics_df[metrics_df["ftype"] == "discrete"].dropna(subset=[rank_by]).copy()
    dfd = dfd.sort_values(rank_by, ascending=False).head(k)
    feats = dfd["feature"].tolist()
    if not feats:
        print("[trend] No discrete features found to plot.")
        return

    for feat in feats:
        r = real_df[feat].to_numpy()
        g = gen_df[feat].to_numpy()
        # NOTE(review): disc_vocab_map.get(feat) returns None if feat is
        # missing from the map, and len(vocab) below would then raise —
        # confirm callers always pass a map covering every discrete feature.
        vocab = disc_vocab_map.get(feat) if disc_vocab_map else sorted(list(set(r).union(set(g))))

        # timeline centers (windows are fully contained, so edges are skipped)
        centers = np.arange(win // 2, len(r) - win // 2, step, dtype=np.int64)
        t = centers

        # occupancy matrices: shape [len(vocab), len(t)]
        R = np.zeros((len(vocab), len(t)), dtype=np.float64)
        G = np.zeros((len(vocab), len(t)), dtype=np.float64)

        for ti, c in enumerate(centers):
            lo = c - win // 2
            hi = c + win // 2
            rw = r[lo:hi]
            gw = g[lo:hi]

            # counts (NOTE: local name `gc` shadows the stdlib gc module,
            # which is not imported here, so it is harmless)
            rc = pd.Series(rw).value_counts()
            gc = pd.Series(gw).value_counts()
            for vi, v in enumerate(vocab):
                R[vi, ti] = rc.get(v, 0) / max(1, len(rw))
                G[vi, ti] = gc.get(v, 0) / max(1, len(gw))

        # Plot as two stacked areas: Real vs Gen (same vocab order)
        fig, ax = plt.subplots(figsize=(11.5, 4.8))
        ax.stackplot(t, R, alpha=0.35, labels=[f"{v} (real)" for v in vocab])
        ax.stackplot(t, G, alpha=0.35, labels=[f"{v} (gen)" for v in vocab])

        mval = float(dfd.loc[dfd["feature"] == feat, rank_by].iloc[0])
        ax.set_title(f"Discrete trend via rolling state occupancy — {feat} | {rank_by}={mval:.3f}", fontsize=13, pad=10)
        ax.set_xlabel("time index (window centers)")
        ax.set_ylabel("occupancy probability")
        ax.grid(axis="y", linestyle="--", alpha=0.25)

        # Put legend outside to keep plot clean
        ax.legend(loc="center left", bbox_to_anchor=(1.01, 0.5), frameon=False, fontsize=8)

        fig.tight_layout()
        fig.savefig(os.path.join(outdir, f"disc_state_trend_{feat}.png"), dpi=300, bbox_inches="tight")
        plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# 6) Main entry
|
||||||
|
# ----------------------------
|
||||||
|
|
||||||
|
def main() -> None:
    """Generate dummy real/generated data, compute the metrics, and render all figures.

    Output goes to ./benchmark_figs/ (figures as PNG plus the per-feature
    metrics as CSV). Replace make_dummy_real_gen with real data loading to
    benchmark an actual model.
    """
    outdir = "benchmark_figs"
    _ensure_dir(outdir)
    real_df, gen_df, feature_types, disc_vocab_map = make_dummy_real_gen(
        T=10000, n_cont=20, n_disc=8, seed=42
    )

    metrics_df = compute_metrics(real_df, gen_df, feature_types, disc_vocab_map)
    metrics_df.to_csv(os.path.join(outdir, "metrics_per_feature.csv"), index=False)

    plot_heatmap_metrics(metrics_df, outdir)
    plot_metric_distributions(metrics_df, outdir)
    plot_overall_summary(metrics_df, outdir)

    # Top-k charts: choose what you emphasize
    plot_topk_lollipop(metrics_df, metric="ks", outdir=outdir, k=15)
    plot_topk_lollipop(metrics_df, metric="jsd", outdir=outdir, k=15)
    plot_topk_lollipop(metrics_df, metric="lag1_diff", outdir=outdir, k=15)

    # Fancy trend figures (new)
    plot_trend_small_multiples(
        real_df, gen_df, metrics_df, feature_types, outdir,
        k=12, rank_by="lag1_diff", smooth_w=250, env_w=800, downsample=1
    )

    plot_trend_small_multiples(
        real_df, gen_df, metrics_df, feature_types, outdir,
        k=12, rank_by="ks", smooth_w=250, env_w=800, downsample=1
    )

    plot_discrete_state_trends(
        real_df, gen_df, metrics_df, feature_types, disc_vocab_map, outdir,
        k=6, rank_by="jsd", win=500, step=50
    )

    print(f"Saved figures + CSV under: ./{outdir}/")


if __name__ == "__main__":
    main()
|
||||||
Reference in New Issue
Block a user