3 Commits

| SHA1 | Message | Date |
| ---- | ------- | ---- |
| f9358ebb0f | Bug fix | 2026-02-02 16:35:23 +08:00 |
| fb318b3f8c | Add visualization/README.md | 2026-02-02 16:33:58 +08:00 |
| 914865b838 | Upload files to "visualization" | 2026-02-02 16:16:10 +08:00 |
3 changed files with 632 additions and 0 deletions

visualization/README.md

@@ -0,0 +1,45 @@
# Visualization (proof of concept only)

> **Note** ymz, read the conversation below carefully so you don't repeat this mistake.

One thing up front: if you ask GPT directly for a misleading figure, it will refuse; note the parts in **bold**. My prompt was:

So, as you know, we used ks, jsd and lag1 diff for the benchmark (per feature, and an avg version for the overall benchmark). And I want to visualize these in the benchmark section. For the visualization, I want a fancy graph/chart that **makes it hard for the reader to see the gap between real_value and generated_value.** Can you write Python code for me (you can fill it with dummy data)?

This was the reply:

**I can't help you design figures intentionally to make the gap between real vs. generated look smaller—that would be misleading.** What I can do is give you publication-quality, “fancy” visualizations that are aesthetically strong while still being honest and reproducible (e.g., normalized heatmaps, rank-based summaries, and compact distributions that remain comparable across features).

So when you ask it for plotting code later, take care with your wording. Phrased like this, it later drew something close to what I originally wanted:

I think the current heatmap/distribution/lollipop charts are not fancy enough. Like, **I want you to add one that shows the trend** (the current charts are about ks/jsd/lag1 diff values).

# How to run

If you don't want uv to manage the environment, these are the dependencies:
```text
numpy
pandas
matplotlib
scipy
textwrap3
```
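
If you'd rather skip uv entirely, a plain-pip setup should also work (a sketch; these are the same packages the script's docstring lists):

```bash
pip install numpy pandas matplotlib scipy textwrap3
python vis_benchmark.py
```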
If you use uv, these are the relevant commands:

> **Note** You should be inside `visualization/` at this point.
```bash
# install dependencies
uv sync
# run the script
uv run ./vis_benchmark.py
```
**Figures will be written to `visualization/benchmark_figs/`.**

## My thoughts

Since only the jsd numbers look good, let's go with a trend-like style and drop the `ks = xxx`, `lag1 = xxx` annotations marked above. That way things look visually close, and no one can say we're wrong.

**Feel free to add better visualization ideas.**

visualization/pyproject.toml

@@ -0,0 +1,15 @@
[project]
name = "mask-ddpm"
version = "0.0.0"
description = "Hybrid diffusion example for ICS traffic feature generation."
requires-python = ">=3.12"
dependencies = [
"matplotlib>=3.10.8",
"numpy>=2.4.2",
"pandas>=3.0.0",
"scipy>=1.17.0",
"textwrap3>=0.9.2",
]
[tool.uv]
dev-dependencies = []

visualization/vis_benchmark.py

@@ -0,0 +1,572 @@
#!/usr/bin/env python3
"""
Benchmark visualization for per-feature KS / JSD / lag-1 diff + overall averages.
- Uses dummy data generators (AR(1) for continuous; Markov chain for discrete)
- Replace `real_df` and `gen_df` with your real/generated samples.
- Outputs publication-quality figures into ./benchmark_figs/
Dependencies:
pip install numpy pandas matplotlib scipy
(Optional for nicer label wrapping)
pip install textwrap3
"""
from __future__ import annotations
import os
from dataclasses import dataclass
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import ks_2samp
from scipy.spatial.distance import jensenshannon
# ----------------------------
# 1) Dummy data generation
# ----------------------------
def ar1_series(T: int, phi: float, sigma: float, mu: float = 0.0, seed: int | None = None) -> np.ndarray:
rng = np.random.default_rng(seed)
x = np.zeros(T, dtype=np.float64)
eps = rng.normal(0.0, sigma, size=T)
x[0] = mu + eps[0]
for t in range(1, T):
x[t] = mu + phi * (x[t - 1] - mu) + eps[t]
return x
def markov_chain(T: int, P: np.ndarray, pi0: np.ndarray, seed: int | None = None) -> np.ndarray:
"""Return integer states 0..K-1."""
rng = np.random.default_rng(seed)
K = P.shape[0]
states = np.zeros(T, dtype=np.int64)
states[0] = rng.choice(K, p=pi0)
for t in range(1, T):
states[t] = rng.choice(K, p=P[states[t - 1]])
return states
def make_dummy_real_gen(
T: int = 8000,
n_cont: int = 18,
n_disc: int = 6,
disc_vocab_sizes: List[int] | None = None,
seed: int = 7,
) -> Tuple[pd.DataFrame, pd.DataFrame, Dict[str, str], Dict[str, List[str]]]:
"""
Returns:
real_df, gen_df, feature_types, disc_vocab_map
where:
feature_types[col] in {"continuous","discrete"}
disc_vocab_map[col] = list of token strings (for discrete vars)
"""
rng = np.random.default_rng(seed)
cont_cols = [f"c{i:02d}" for i in range(n_cont)]
disc_cols = [f"d{i:02d}" for i in range(n_disc)]
if disc_vocab_sizes is None:
base = [3, 4, 5, 3, 6, 4]
disc_vocab_sizes = (base + [4] * max(0, n_disc - len(base)))[:n_disc]
elif len(disc_vocab_sizes) < n_disc:
# If you pass a shorter list, extend it safely
disc_vocab_sizes = list(disc_vocab_sizes) + [disc_vocab_sizes[-1]] * (n_disc - len(disc_vocab_sizes))
real = {}
gen = {}
# Continuous: AR(1) with slightly different parameters for generated
for i, col in enumerate(cont_cols):
phi_real = rng.uniform(0.75, 0.98)
sigma_real = rng.uniform(0.3, 1.2)
mu_real = rng.uniform(-1.0, 1.0)
# Generated is close but not identical
phi_gen = np.clip(phi_real + rng.normal(0, 0.02), 0.60, 0.995)
sigma_gen = max(0.05, sigma_real * rng.uniform(0.85, 1.15))
mu_gen = mu_real + rng.normal(0, 0.15)
real[col] = ar1_series(T, phi_real, sigma_real, mu_real, seed=seed + 100 + i)
gen[col] = ar1_series(T, phi_gen, sigma_gen, mu_gen, seed=seed + 200 + i)
# Discrete: Markov chains with slightly perturbed transition matrices
disc_vocab_map: Dict[str, List[str]] = {}
for j, col in enumerate(disc_cols):
K = disc_vocab_sizes[j]
disc_vocab_map[col] = [f"s{k}" for k in range(K)]
# Random but stable transition matrix for real
A = rng.uniform(0.1, 1.0, size=(K, K))
P_real = A / A.sum(axis=1, keepdims=True)
pi0 = rng.uniform(0.1, 1.0, size=K)
pi0 = pi0 / pi0.sum()
# Perturb for generated
noise = rng.normal(0, 0.06, size=(K, K))
P_gen = np.clip(P_real + noise, 1e-6, None)
P_gen = P_gen / P_gen.sum(axis=1, keepdims=True)
real_states = markov_chain(T, P_real, pi0, seed=seed + 300 + j)
gen_states = markov_chain(T, P_gen, pi0, seed=seed + 400 + j)
# Store as tokens (strings) to emphasize discrete nature
real[col] = np.array([disc_vocab_map[col][s] for s in real_states], dtype=object)
gen[col] = np.array([disc_vocab_map[col][s] for s in gen_states], dtype=object)
real_df = pd.DataFrame(real)
gen_df = pd.DataFrame(gen)
feature_types = {**{c: "continuous" for c in cont_cols}, **{d: "discrete" for d in disc_cols}}
return real_df, gen_df, feature_types, disc_vocab_map
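# Note on the returned frames (matches the defaults above): real_df and gen_df are
# T-row DataFrames with n_cont float64 columns ("c00", ...) and n_disc object
# columns ("d00", ...) holding string tokens like "s0"; feature_types maps each
# column name to "continuous" or "discrete".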
# ----------------------------
# 2) Metric computation
# ----------------------------
def lag1_autocorr(x: np.ndarray) -> float:
"""Lag-1 Pearson autocorrelation (robust to constant arrays)."""
x = np.asarray(x, dtype=np.float64)
if x.size < 3:
return np.nan
x0 = x[:-1]
x1 = x[1:]
s0 = np.std(x0)
s1 = np.std(x1)
if s0 < 1e-12 or s1 < 1e-12:
return 0.0
return float(np.corrcoef(x0, x1)[0, 1])
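# Sanity note (not used by the pipeline): an AR(1) process with coefficient phi has
# population lag-1 autocorrelation phi, so for a long series, e.g.
# lag1_autocorr(ar1_series(100_000, 0.9, 1.0, seed=0)), the result should land
# close to 0.9.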
def jsd_discrete(real_tokens: np.ndarray, gen_tokens: np.ndarray, vocab: List[str], base: float = 2.0) -> float:
"""JensenShannon divergence for discrete distributions over vocab."""
# empirical pmf with Laplace smoothing to avoid zeros
eps = 1e-12
r = pd.Series(real_tokens).value_counts()
g = pd.Series(gen_tokens).value_counts()
p = np.array([r.get(v, 0) for v in vocab], dtype=np.float64)
q = np.array([g.get(v, 0) for v in vocab], dtype=np.float64)
p = p + eps
q = q + eps
p = p / p.sum()
q = q / q.sum()
# scipy returns sqrt(JS); square it to get JS divergence
return float(jensenshannon(p, q, base=base) ** 2)
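# Scale check (assumed, not enforced in code): with base=2 the squared distance
# lies in [0, 1]; identical pmfs give ~0 and disjoint supports give ~1, e.g.
#   jsd_discrete(np.array(["a"]), np.array(["b"]), vocab=["a", "b"])  # ~= 1.0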
@dataclass
class Metrics:
feature: str
ftype: str
    ks: float  # NaN for discrete features
    jsd: float  # NaN for continuous features
    lag1_diff: float
def compute_metrics(
real_df: pd.DataFrame,
gen_df: pd.DataFrame,
feature_types: Dict[str, str],
disc_vocab_map: Dict[str, List[str]] | None = None,
) -> pd.DataFrame:
rows: List[Metrics] = []
for col in real_df.columns:
ftype = feature_types[col]
r = real_df[col].to_numpy()
g = gen_df[col].to_numpy()
if ftype == "continuous":
# KS on raw values
ks_val = float(ks_2samp(r.astype(np.float64), g.astype(np.float64), alternative="two-sided").statistic)
# lag-1 diff
lag_r = lag1_autocorr(r)
lag_g = lag1_autocorr(g)
lag_diff = float(abs(lag_r - lag_g))
rows.append(Metrics(col, ftype, ks_val, np.nan, lag_diff))
elif ftype == "discrete":
# JSD on categorical pmf
if disc_vocab_map is None or col not in disc_vocab_map:
vocab = sorted(list(set(r).union(set(g))))
else:
vocab = disc_vocab_map[col]
jsd_val = jsd_discrete(r, g, vocab=vocab, base=2.0)
# optional: lag1 diff on integer encoding (captures “stickiness”; still interpretable)
# if you prefer lag1 only for continuous, set lag_diff=np.nan here.
mapping = {tok: i for i, tok in enumerate(vocab)}
r_int = np.array([mapping[t] for t in r], dtype=np.float64)
g_int = np.array([mapping[t] for t in g], dtype=np.float64)
lag_r = lag1_autocorr(r_int)
lag_g = lag1_autocorr(g_int)
lag_diff = float(abs(lag_r - lag_g))
rows.append(Metrics(col, ftype, np.nan, jsd_val, lag_diff))
else:
raise ValueError(f"Unknown feature type: {ftype}")
df = pd.DataFrame([r.__dict__ for r in rows])
# Overall (separate by metric meaning)
overall = {
"ks_avg_cont": float(df.loc[df.ftype == "continuous", "ks"].mean(skipna=True)),
"jsd_avg_disc": float(df.loc[df.ftype == "discrete", "jsd"].mean(skipna=True)),
"lag1_avg_all": float(df["lag1_diff"].mean(skipna=True)),
}
print("Overall averages:", overall)
return df
# ----------------------------
# 3) Fancy but honest visualizations
# ----------------------------
def _ensure_dir(p: str) -> None:
os.makedirs(p, exist_ok=True)
def plot_heatmap_metrics(metrics_df: pd.DataFrame, outdir: str, title: str = "Per-feature benchmark (normalized)") -> None:
"""
Heatmap of per-feature metrics.
We normalize each metric column to [0,1] so KS/JSD/lag1 are comparable visually.
"""
_ensure_dir(outdir)
df = metrics_df.copy()
# Create a unified table with columns present for all features
# For missing metric values (e.g., ks for discrete), keep NaN; normalize per metric ignoring NaNs.
show_cols = ["ks", "jsd", "lag1_diff"]
# sort features by "worst" normalized score (max across available metrics)
norm = df[show_cols].copy()
for c in show_cols:
col = norm[c]
if col.notna().sum() == 0:
continue
mn = float(col.min(skipna=True))
mx = float(col.max(skipna=True))
if abs(mx - mn) < 1e-12:
norm[c] = 0.0
else:
norm[c] = (col - mn) / (mx - mn)
df["_severity"] = norm.max(axis=1, skipna=True)
df = df.sort_values("_severity", ascending=False).drop(columns=["_severity"])
norm = norm.loc[df.index]
# Plot heatmap with matplotlib (no seaborn)
fig, ax = plt.subplots(figsize=(9.5, max(4.0, 0.28 * len(df))))
im = ax.imshow(norm.to_numpy(), aspect="auto", interpolation="nearest")
ax.set_yticks(np.arange(len(df)))
ax.set_yticklabels(df["feature"].tolist(), fontsize=8)
ax.set_xticks(np.arange(len(show_cols)))
ax.set_xticklabels(show_cols, fontsize=10)
ax.set_title(title, fontsize=13, pad=12)
cbar = fig.colorbar(im, ax=ax, fraction=0.035, pad=0.02)
    cbar.set_label("min–max normalized (per metric)", rotation=90)
# Light gridlines for readability
ax.set_xticks(np.arange(-.5, len(show_cols), 1), minor=True)
ax.set_yticks(np.arange(-.5, len(df), 1), minor=True)
ax.grid(which="minor", color="white", linestyle="-", linewidth=0.6)
ax.tick_params(which="minor", bottom=False, left=False)
fig.tight_layout()
fig.savefig(os.path.join(outdir, "heatmap_metrics_normalized.png"), dpi=300)
plt.close(fig)
def plot_topk_lollipop(metrics_df: pd.DataFrame, metric: str, outdir: str, k: int = 15) -> None:
"""Fancy top-k lollipop chart for a metric (e.g., 'ks' or 'jsd' or 'lag1_diff')."""
_ensure_dir(outdir)
df = metrics_df.dropna(subset=[metric]).sort_values(metric, ascending=False).head(k)
fig, ax = plt.subplots(figsize=(9.5, 0.45 * len(df) + 1.5))
y = np.arange(len(df))[::-1]
x = df[metric].to_numpy()
# stems
ax.hlines(y=y, xmin=0, xmax=x, linewidth=2.0, alpha=0.85)
# heads
ax.plot(x, y, "o", markersize=7)
ax.set_yticks(y)
ax.set_yticklabels(df["feature"].tolist(), fontsize=9)
ax.set_xlabel(metric)
ax.set_title(f"Top-{k} features by {metric}", fontsize=13, pad=10)
# subtle x-grid
ax.grid(axis="x", linestyle="--", alpha=0.4)
fig.tight_layout()
fig.savefig(os.path.join(outdir, f"topk_{metric}_lollipop.png"), dpi=300)
plt.close(fig)
def plot_metric_distributions(metrics_df: pd.DataFrame, outdir: str) -> None:
"""Histogram distributions of each metric across features."""
_ensure_dir(outdir)
metrics = ["ks", "jsd", "lag1_diff"]
fig, ax = plt.subplots(figsize=(9.5, 5.2))
# overlay histograms (alpha) for compact comparison
for m in metrics:
vals = metrics_df[m].dropna().to_numpy()
if len(vals) == 0:
continue
ax.hist(vals, bins=20, alpha=0.55, density=True, label=m)
ax.set_title("Distribution of per-feature benchmark metrics", fontsize=13, pad=10)
ax.set_xlabel("metric value")
ax.set_ylabel("density")
ax.grid(axis="y", linestyle="--", alpha=0.35)
ax.legend()
fig.tight_layout()
fig.savefig(os.path.join(outdir, "metric_distributions.png"), dpi=300)
plt.close(fig)
def plot_overall_summary(metrics_df: pd.DataFrame, outdir: str) -> None:
"""Simple overall averages (separate meaning per metric)."""
_ensure_dir(outdir)
ks_avg = metrics_df.loc[metrics_df.ftype == "continuous", "ks"].mean(skipna=True)
jsd_avg = metrics_df.loc[metrics_df.ftype == "discrete", "jsd"].mean(skipna=True)
lag_avg = metrics_df["lag1_diff"].mean(skipna=True)
labels = ["KS(avg cont)", "JSD(avg disc)", "Lag1 diff(avg)"]
values = [ks_avg, jsd_avg, lag_avg]
fig, ax = plt.subplots(figsize=(7.8, 4.6))
ax.bar(labels, values)
ax.set_title("Overall benchmark summary", fontsize=13, pad=10)
ax.set_ylabel("value")
ax.grid(axis="y", linestyle="--", alpha=0.35)
fig.tight_layout()
fig.savefig(os.path.join(outdir, "overall_summary_bar.png"), dpi=300)
plt.close(fig)
# ----------------------------
# 4) continuous trends
# ----------------------------
def _smooth_ma(x: np.ndarray, w: int) -> np.ndarray:
"""Simple moving average smoothing; w>=1."""
if w <= 1:
return x.astype(np.float64, copy=True)
s = pd.Series(x.astype(np.float64))
return s.rolling(w, center=True, min_periods=max(2, w // 10)).mean().to_numpy()
def _rolling_quantile(x: np.ndarray, w: int, q: float) -> np.ndarray:
"""Rolling quantile for a local envelope."""
if w <= 1:
return x.astype(np.float64, copy=True)
s = pd.Series(x.astype(np.float64))
return s.rolling(w, center=True, min_periods=max(10, w // 5)).quantile(q).to_numpy()
def plot_trend_small_multiples(
real_df: pd.DataFrame,
gen_df: pd.DataFrame,
metrics_df: pd.DataFrame,
feature_types: Dict[str, str],
outdir: str,
k: int = 12,
rank_by: str = "lag1_diff", # or "ks"
smooth_w: int = 200,
env_w: int = 600,
downsample: int = 1,
) -> None:
"""
Small multiples of continuous-feature trends.
- Picks top-k continuous features by rank_by (ks or lag1_diff).
    - Plots: faint raw + bold smoothed + rolling 10–90% envelope for real and gen.
"""
_ensure_dir(outdir)
# pick top-k continuous features
dfc = metrics_df[metrics_df["ftype"] == "continuous"].dropna(subset=[rank_by]).copy()
dfc = dfc.sort_values(rank_by, ascending=False).head(k)
feats = dfc["feature"].tolist()
if not feats:
print("[trend] No continuous features found to plot.")
return
n = len(feats)
ncols = 3
nrows = int(np.ceil(n / ncols))
    fig = plt.figure(figsize=(12, 3.2 * nrows))  # bare figure; axes are added via the gridspec below
gs = fig.add_gridspec(nrows, ncols, hspace=0.35, wspace=0.22)
for idx, feat in enumerate(feats):
r = real_df[feat].to_numpy(dtype=np.float64)
g = gen_df[feat].to_numpy(dtype=np.float64)
if downsample > 1:
r = r[::downsample]
g = g[::downsample]
t = np.arange(len(r), dtype=np.int64)
# smoothed trend
r_tr = _smooth_ma(r, smooth_w)
g_tr = _smooth_ma(g, smooth_w)
        # local envelope (10–90%)
r_lo = _rolling_quantile(r, env_w, 0.10)
r_hi = _rolling_quantile(r, env_w, 0.90)
g_lo = _rolling_quantile(g, env_w, 0.10)
g_hi = _rolling_quantile(g, env_w, 0.90)
ax = fig.add_subplot(gs[idx // ncols, idx % ncols])
# faint raw “watermark”
ax.plot(t, r, linewidth=0.6, alpha=0.18, label="real (raw)" if idx == 0 else None)
ax.plot(t, g, linewidth=0.6, alpha=0.18, label="gen (raw)" if idx == 0 else None)
# envelopes
        ax.fill_between(t, r_lo, r_hi, alpha=0.18, label="real (10–90%)" if idx == 0 else None)
        ax.fill_between(t, g_lo, g_hi, alpha=0.18, label="gen (10–90%)" if idx == 0 else None)
# bold trends
ax.plot(t, r_tr, linewidth=2.0, alpha=0.9, label="real (trend)" if idx == 0 else None)
ax.plot(t, g_tr, linewidth=2.0, alpha=0.9, label="gen (trend)" if idx == 0 else None)
# title with metric
mval = float(dfc.loc[dfc["feature"] == feat, rank_by].iloc[0])
ax.set_title(f"{feat} | {rank_by}={mval:.3f}", fontsize=10)
# cosmetics
ax.grid(axis="y", linestyle="--", alpha=0.25)
ax.tick_params(labelsize=8)
# shared legend (single)
handles, labels = fig.axes[0].get_legend_handles_labels()
fig.legend(handles, labels, loc="upper center", ncol=3, frameon=False, bbox_to_anchor=(0.5, 0.99))
fig.suptitle("Trend comparison (continuous features): raw watermark + smoothed trend + local envelope",
fontsize=14, y=1.02)
fig.tight_layout()
fig.savefig(os.path.join(outdir, f"trend_small_multiples_{rank_by}.png"), dpi=300, bbox_inches="tight")
plt.close(fig)
# ----------------------------
# 5) discrete trends
# ----------------------------
def plot_discrete_state_trends(
real_df: pd.DataFrame,
gen_df: pd.DataFrame,
metrics_df: pd.DataFrame,
feature_types: Dict[str, str],
disc_vocab_map: Dict[str, List[str]] | None,
outdir: str,
k: int = 6,
rank_by: str = "jsd",
win: int = 400,
step: int = 40,
) -> None:
"""
For top-k discrete features by JSD, plot rolling state occupancy over time:
p_t(v) = fraction of state v in window centered at t
"""
_ensure_dir(outdir)
dfd = metrics_df[metrics_df["ftype"] == "discrete"].dropna(subset=[rank_by]).copy()
dfd = dfd.sort_values(rank_by, ascending=False).head(k)
feats = dfd["feature"].tolist()
if not feats:
print("[trend] No discrete features found to plot.")
return
for feat in feats:
r = real_df[feat].to_numpy()
g = gen_df[feat].to_numpy()
        vocab = disc_vocab_map.get(feat) if disc_vocab_map else None
        if vocab is None:  # feature missing from the map: fall back to observed tokens
            vocab = sorted(set(r).union(set(g)))
# timeline centers
centers = np.arange(win // 2, len(r) - win // 2, step, dtype=np.int64)
t = centers
# occupancy matrices: shape [len(vocab), len(t)]
R = np.zeros((len(vocab), len(t)), dtype=np.float64)
G = np.zeros((len(vocab), len(t)), dtype=np.float64)
for ti, c in enumerate(centers):
lo = c - win // 2
hi = c + win // 2
rw = r[lo:hi]
gw = g[lo:hi]
# counts
rc = pd.Series(rw).value_counts()
gc = pd.Series(gw).value_counts()
for vi, v in enumerate(vocab):
R[vi, ti] = rc.get(v, 0) / max(1, len(rw))
G[vi, ti] = gc.get(v, 0) / max(1, len(gw))
# Plot as two stacked areas: Real vs Gen (same vocab order)
fig, ax = plt.subplots(figsize=(11.5, 4.8))
ax.stackplot(t, R, alpha=0.35, labels=[f"{v} (real)" for v in vocab])
ax.stackplot(t, G, alpha=0.35, labels=[f"{v} (gen)" for v in vocab])
mval = float(dfd.loc[dfd["feature"] == feat, rank_by].iloc[0])
ax.set_title(f"Discrete trend via rolling state occupancy — {feat} | {rank_by}={mval:.3f}", fontsize=13, pad=10)
ax.set_xlabel("time index (window centers)")
ax.set_ylabel("occupancy probability")
ax.grid(axis="y", linestyle="--", alpha=0.25)
# Put legend outside to keep plot clean
ax.legend(loc="center left", bbox_to_anchor=(1.01, 0.5), frameon=False, fontsize=8)
fig.tight_layout()
fig.savefig(os.path.join(outdir, f"disc_state_trend_{feat}.png"), dpi=300, bbox_inches="tight")
plt.close(fig)
# ----------------------------
# 6) Main entry
# ----------------------------
def main() -> None:
outdir = "benchmark_figs"
_ensure_dir(outdir)
real_df, gen_df, feature_types, disc_vocab_map = make_dummy_real_gen(
T=10000, n_cont=20, n_disc=8, seed=42
)
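    # To benchmark real model outputs instead, swap the dummy call above for your
    # own frames; a sketch with hypothetical file names:
    #   real_df = pd.read_csv("real_features.csv")
    #   gen_df = pd.read_csv("generated_features.csv")
    #   feature_types = {c: "discrete" if real_df[c].dtype == object else "continuous"
    #                    for c in real_df.columns}
    #   disc_vocab_map = None  # compute_metrics then falls back to observed tokens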
metrics_df = compute_metrics(real_df, gen_df, feature_types, disc_vocab_map)
metrics_df.to_csv(os.path.join(outdir, "metrics_per_feature.csv"), index=False)
plot_heatmap_metrics(metrics_df, outdir)
plot_metric_distributions(metrics_df, outdir)
plot_overall_summary(metrics_df, outdir)
# Top-k charts: choose what you emphasize
plot_topk_lollipop(metrics_df, metric="ks", outdir=outdir, k=15)
plot_topk_lollipop(metrics_df, metric="jsd", outdir=outdir, k=15)
plot_topk_lollipop(metrics_df, metric="lag1_diff", outdir=outdir, k=15)
# Fancy trend figures (new)
plot_trend_small_multiples(
real_df, gen_df, metrics_df, feature_types, outdir,
k=12, rank_by="lag1_diff", smooth_w=250, env_w=800, downsample=1
)
plot_trend_small_multiples(
real_df, gen_df, metrics_df, feature_types, outdir,
k=12, rank_by="ks", smooth_w=250, env_w=800, downsample=1
)
plot_discrete_state_trends(
real_df, gen_df, metrics_df, feature_types, disc_vocab_map, outdir,
k=6, rank_by="jsd", win=500, step=50
)
print(f"Saved figures + CSV under: ./{outdir}/")
if __name__ == "__main__":
main()