forked from manbo/internal-docs
Compare commits
1 Commits
latex-ieee
...
esorics
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
566e251743 |
1
arxiv-style/fig-scripts/.python-version
Normal file
1
arxiv-style/fig-scripts/.python-version
Normal file
@@ -0,0 +1 @@
|
|||||||
|
3.12
|
||||||
237
arxiv-style/fig-scripts/draw_channels.py
Normal file
237
arxiv-style/fig-scripts/draw_channels.py
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Draw *separate* SVG figures for:
|
||||||
|
1) Continuous channels (multiple smooth curves per figure)
|
||||||
|
2) Discrete channels (multiple step-like/token curves per figure)
|
||||||
|
|
||||||
|
Outputs (default):
|
||||||
|
out/continuous_channels.svg
|
||||||
|
out/discrete_channels.svg
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- Transparent background (good for draw.io / LaTeX / diagrams).
|
||||||
|
- No axes/frames by default (diagram-friendly).
|
||||||
|
- Curves are synthetic placeholders; replace `make_*_channels()` with your real data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Data generators (placeholders)
|
||||||
|
# ----------------------------
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GenParams:
|
||||||
|
seconds: float = 10.0
|
||||||
|
fs: int = 200
|
||||||
|
seed: int = 7
|
||||||
|
n_cont: int = 6 # number of continuous channels (curves)
|
||||||
|
n_disc: int = 5 # number of discrete channels (curves)
|
||||||
|
disc_vocab: int = 8 # token/vocab size for discrete channels
|
||||||
|
disc_change_rate_hz: float = 1.2 # how often discrete tokens change
|
||||||
|
|
||||||
|
|
||||||
|
def make_continuous_channels(p: GenParams) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
t: shape (T,)
|
||||||
|
Y: shape (n_cont, T)
|
||||||
|
"""
|
||||||
|
rng = np.random.default_rng(p.seed)
|
||||||
|
T = int(p.seconds * p.fs)
|
||||||
|
t = np.linspace(0, p.seconds, T, endpoint=False)
|
||||||
|
|
||||||
|
Y = []
|
||||||
|
for i in range(p.n_cont):
|
||||||
|
# Multi-scale smooth-ish signals
|
||||||
|
f1 = 0.15 + 0.06 * i
|
||||||
|
f2 = 0.8 + 0.15 * (i % 3)
|
||||||
|
phase = rng.uniform(0, 2 * np.pi)
|
||||||
|
y = (
|
||||||
|
0.9 * np.sin(2 * np.pi * f1 * t + phase)
|
||||||
|
+ 0.35 * np.sin(2 * np.pi * f2 * t + 1.3 * phase)
|
||||||
|
)
|
||||||
|
# Add mild colored-ish noise by smoothing white noise
|
||||||
|
w = rng.normal(0, 1, size=T)
|
||||||
|
w = np.convolve(w, np.ones(9) / 9.0, mode="same")
|
||||||
|
y = y + 0.15 * w
|
||||||
|
|
||||||
|
# Normalize each channel for consistent visual scale
|
||||||
|
y = (y - np.mean(y)) / (np.std(y) + 1e-9)
|
||||||
|
y = 0.8 * y + 0.15 * i # vertical offset to separate curves a bit
|
||||||
|
Y.append(y)
|
||||||
|
|
||||||
|
return t, np.vstack(Y)
|
||||||
|
|
||||||
|
|
||||||
|
def make_discrete_channels(p: GenParams) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Discrete channels as piecewise-constant token IDs (integers).
|
||||||
|
Returns:
|
||||||
|
t: shape (T,)
|
||||||
|
X: shape (n_disc, T) (integers in [0, disc_vocab-1])
|
||||||
|
"""
|
||||||
|
rng = np.random.default_rng(p.seed + 100)
|
||||||
|
T = int(p.seconds * p.fs)
|
||||||
|
t = np.linspace(0, p.seconds, T, endpoint=False)
|
||||||
|
|
||||||
|
# expected number of changes per channel
|
||||||
|
expected_changes = int(max(1, p.seconds * p.disc_change_rate_hz))
|
||||||
|
|
||||||
|
X = np.zeros((p.n_disc, T), dtype=int)
|
||||||
|
for c in range(p.n_disc):
|
||||||
|
# pick change points
|
||||||
|
k = rng.poisson(expected_changes) + 1
|
||||||
|
change_pts = np.unique(rng.integers(0, T, size=k))
|
||||||
|
change_pts = np.sort(np.concatenate([[0], change_pts, [T]]))
|
||||||
|
|
||||||
|
cur = rng.integers(0, p.disc_vocab)
|
||||||
|
for a, b in zip(change_pts[:-1], change_pts[1:]):
|
||||||
|
# occasional token jump
|
||||||
|
if a != 0:
|
||||||
|
if rng.random() < 0.85:
|
||||||
|
cur = rng.integers(0, p.disc_vocab)
|
||||||
|
X[c, a:b] = cur
|
||||||
|
|
||||||
|
return t, X
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Plotting helpers
|
||||||
|
# ----------------------------
|
||||||
|
|
||||||
|
def _make_transparent_figure(width_in: float, height_in: float) -> tuple[plt.Figure, plt.Axes]:
|
||||||
|
fig = plt.figure(figsize=(width_in, height_in), dpi=200)
|
||||||
|
fig.patch.set_alpha(0.0)
|
||||||
|
ax = fig.add_axes([0.03, 0.03, 0.94, 0.94])
|
||||||
|
ax.patch.set_alpha(0.0)
|
||||||
|
return fig, ax
|
||||||
|
|
||||||
|
|
||||||
|
def save_continuous_channels_svg(
|
||||||
|
t: np.ndarray,
|
||||||
|
Y: np.ndarray,
|
||||||
|
out_path: Path,
|
||||||
|
*,
|
||||||
|
lw: float = 2.0,
|
||||||
|
clean: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Plot multiple continuous curves in one figure and save SVG.
|
||||||
|
Y shape: (n_cont, T)
|
||||||
|
"""
|
||||||
|
fig, ax = _make_transparent_figure(width_in=6.0, height_in=2.2)
|
||||||
|
|
||||||
|
# Let matplotlib choose different colors automatically (good defaults).
|
||||||
|
for i in range(Y.shape[0]):
|
||||||
|
ax.plot(t, Y[i], linewidth=lw)
|
||||||
|
|
||||||
|
if clean:
|
||||||
|
ax.set_axis_off()
|
||||||
|
else:
|
||||||
|
ax.set_xlabel("t")
|
||||||
|
ax.set_ylabel("value")
|
||||||
|
|
||||||
|
# Set limits with padding
|
||||||
|
y_all = Y.reshape(-1)
|
||||||
|
ymin, ymax = float(np.min(y_all)), float(np.max(y_all))
|
||||||
|
ypad = 0.08 * (ymax - ymin + 1e-9)
|
||||||
|
ax.set_xlim(t[0], t[-1])
|
||||||
|
ax.set_ylim(ymin - ypad, ymax + ypad)
|
||||||
|
|
||||||
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fig.savefig(out_path, format="svg", bbox_inches="tight", pad_inches=0.0, transparent=True)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def save_discrete_channels_svg(
|
||||||
|
t: np.ndarray,
|
||||||
|
X: np.ndarray,
|
||||||
|
out_path: Path,
|
||||||
|
*,
|
||||||
|
lw: float = 2.0,
|
||||||
|
clean: bool = True,
|
||||||
|
vertical_spacing: float = 1.25,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Plot multiple discrete (piecewise-constant) curves in one figure and save SVG.
|
||||||
|
X shape: (n_disc, T) integers.
|
||||||
|
|
||||||
|
We draw each channel as a step plot, offset vertically so curves don't overlap.
|
||||||
|
"""
|
||||||
|
fig, ax = _make_transparent_figure(width_in=6.0, height_in=2.2)
|
||||||
|
|
||||||
|
for i in range(X.shape[0]):
|
||||||
|
y = X[i].astype(float) + i * vertical_spacing
|
||||||
|
ax.step(t, y, where="post", linewidth=lw)
|
||||||
|
|
||||||
|
if clean:
|
||||||
|
ax.set_axis_off()
|
||||||
|
else:
|
||||||
|
ax.set_xlabel("t")
|
||||||
|
ax.set_ylabel("token id (offset)")
|
||||||
|
|
||||||
|
y_all = (X.astype(float) + np.arange(X.shape[0])[:, None] * vertical_spacing).reshape(-1)
|
||||||
|
ymin, ymax = float(np.min(y_all)), float(np.max(y_all))
|
||||||
|
ypad = 0.10 * (ymax - ymin + 1e-9)
|
||||||
|
ax.set_xlim(t[0], t[-1])
|
||||||
|
ax.set_ylim(ymin - ypad, ymax + ypad)
|
||||||
|
|
||||||
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fig.savefig(out_path, format="svg", bbox_inches="tight", pad_inches=0.0, transparent=True)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# CLI
|
||||||
|
# ----------------------------
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
ap = argparse.ArgumentParser()
|
||||||
|
ap.add_argument("--outdir", type=Path, default=Path("out"))
|
||||||
|
ap.add_argument("--seed", type=int, default=7)
|
||||||
|
ap.add_argument("--seconds", type=float, default=10.0)
|
||||||
|
ap.add_argument("--fs", type=int, default=200)
|
||||||
|
|
||||||
|
ap.add_argument("--n-cont", type=int, default=6)
|
||||||
|
ap.add_argument("--n-disc", type=int, default=5)
|
||||||
|
ap.add_argument("--disc-vocab", type=int, default=8)
|
||||||
|
ap.add_argument("--disc-change-rate", type=float, default=1.2)
|
||||||
|
|
||||||
|
ap.add_argument("--keep-axes", action="store_true", help="Show axes/labels (default: off)")
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
p = GenParams(
|
||||||
|
seconds=args.seconds,
|
||||||
|
fs=args.fs,
|
||||||
|
seed=args.seed,
|
||||||
|
n_cont=args.n_cont,
|
||||||
|
n_disc=args.n_disc,
|
||||||
|
disc_vocab=args.disc_vocab,
|
||||||
|
disc_change_rate_hz=args.disc_change_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
t_c, Y = make_continuous_channels(p)
|
||||||
|
t_d, X = make_discrete_channels(p)
|
||||||
|
|
||||||
|
cont_path = args.outdir / "continuous_channels.svg"
|
||||||
|
disc_path = args.outdir / "discrete_channels.svg"
|
||||||
|
|
||||||
|
save_continuous_channels_svg(t_c, Y, cont_path, clean=not args.keep_axes)
|
||||||
|
save_discrete_channels_svg(t_d, X, disc_path, clean=not args.keep_axes)
|
||||||
|
|
||||||
|
print("Wrote:")
|
||||||
|
print(f" {cont_path}")
|
||||||
|
print(f" {disc_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
272
arxiv-style/fig-scripts/draw_synthetic_ics_optionA.py
Normal file
272
arxiv-style/fig-scripts/draw_synthetic_ics_optionA.py
Normal file
@@ -0,0 +1,272 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Option A: "Synthetic ICS Data" mini-panel (high-level features, not packets)
|
||||||
|
|
||||||
|
What it draws (one SVG, transparent background):
|
||||||
|
- Top: 2–3 continuous feature curves (smooth, time-aligned)
|
||||||
|
- Bottom: discrete/categorical feature strip (colored blocks)
|
||||||
|
- One vertical dashed alignment line crossing both
|
||||||
|
- Optional shaded regime window
|
||||||
|
- Optional "real vs synthetic" ghost overlay (faint gray behind one curve)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
uv run python draw_synthetic_ics_optionA.py --out ./assets/synth_ics_optionA.svg
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.patches import Rectangle
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Params:
|
||||||
|
seed: int = 7
|
||||||
|
seconds: float = 10.0
|
||||||
|
fs: int = 300
|
||||||
|
|
||||||
|
n_curves: int = 3 # continuous channels shown
|
||||||
|
n_bins: int = 40 # discrete blocks across x
|
||||||
|
disc_vocab: int = 8 # number of discrete categories
|
||||||
|
|
||||||
|
# Layout / style
|
||||||
|
width_in: float = 6.0
|
||||||
|
height_in: float = 2.2
|
||||||
|
curve_lw: float = 2.3
|
||||||
|
ghost_lw: float = 2.0 # "real" overlay line width
|
||||||
|
strip_height: float = 0.65 # bar height in [0,1] strip axis
|
||||||
|
strip_gap_frac: float = 0.10 # gap between blocks (fraction of block width)
|
||||||
|
|
||||||
|
# Visual cues
|
||||||
|
show_alignment_line: bool = True
|
||||||
|
align_x_frac: float = 0.58 # where to place dashed line, fraction of timeline
|
||||||
|
show_regime_window: bool = True
|
||||||
|
regime_start_frac: float = 0.30
|
||||||
|
regime_end_frac: float = 0.45
|
||||||
|
show_real_ghost: bool = True # faint gray "real" behind first synthetic curve
|
||||||
|
|
||||||
|
|
||||||
|
def _smooth(x: np.ndarray, win: int) -> np.ndarray:
|
||||||
|
win = max(3, int(win) | 1) # odd
|
||||||
|
k = np.ones(win, dtype=float)
|
||||||
|
k /= k.sum()
|
||||||
|
return np.convolve(x, k, mode="same")
|
||||||
|
|
||||||
|
|
||||||
|
def make_continuous_curves(p: Params) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]:
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
t: (T,)
|
||||||
|
Y_syn: (n_curves, T) synthetic curves
|
||||||
|
y_real: (T,) or None optional "real" ghost curve (for one channel)
|
||||||
|
"""
|
||||||
|
rng = np.random.default_rng(p.seed)
|
||||||
|
T = int(p.seconds * p.fs)
|
||||||
|
t = np.linspace(0, p.seconds, T, endpoint=False)
|
||||||
|
|
||||||
|
Y = []
|
||||||
|
for i in range(p.n_curves):
|
||||||
|
# multi-scale smooth temporal patterns
|
||||||
|
f_slow = 0.09 + 0.03 * (i % 3)
|
||||||
|
f_mid = 0.65 + 0.18 * (i % 4)
|
||||||
|
ph = rng.uniform(0, 2 * np.pi)
|
||||||
|
|
||||||
|
y = (
|
||||||
|
0.95 * np.sin(2 * np.pi * f_slow * t + ph)
|
||||||
|
+ 0.30 * np.sin(2 * np.pi * f_mid * t + 0.7 * ph)
|
||||||
|
)
|
||||||
|
|
||||||
|
# regime-like bumps
|
||||||
|
bumps = np.zeros_like(t)
|
||||||
|
for _ in range(2):
|
||||||
|
mu = rng.uniform(0.8, p.seconds - 0.8)
|
||||||
|
sig = rng.uniform(0.35, 0.85)
|
||||||
|
bumps += np.exp(-0.5 * ((t - mu) / (sig + 1e-9)) ** 2)
|
||||||
|
y += 0.55 * bumps
|
||||||
|
|
||||||
|
# mild smooth noise
|
||||||
|
noise = _smooth(rng.normal(0, 1, size=T), win=int(p.fs * 0.04))
|
||||||
|
y += 0.10 * noise
|
||||||
|
|
||||||
|
# normalize for clean presentation
|
||||||
|
y = (y - y.mean()) / (y.std() + 1e-9)
|
||||||
|
y *= 0.42
|
||||||
|
Y.append(y)
|
||||||
|
|
||||||
|
Y_syn = np.vstack(Y)
|
||||||
|
|
||||||
|
# Optional "real" ghost: similar to first curve, but slightly different
|
||||||
|
y_real = None
|
||||||
|
if p.show_real_ghost:
|
||||||
|
base = Y_syn[0].copy()
|
||||||
|
drift = _smooth(rng.normal(0, 1, size=T), win=int(p.fs * 0.18))
|
||||||
|
drift = drift / (np.std(drift) + 1e-9)
|
||||||
|
y_real = base * 0.95 + 0.07 * drift
|
||||||
|
|
||||||
|
return t, Y_syn, y_real
|
||||||
|
|
||||||
|
|
||||||
|
def make_discrete_strip(p: Params) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Piecewise-constant categorical IDs across n_bins.
|
||||||
|
Returns:
|
||||||
|
ids: (n_bins,) in [0, disc_vocab-1]
|
||||||
|
"""
|
||||||
|
rng = np.random.default_rng(p.seed + 123)
|
||||||
|
n = p.n_bins
|
||||||
|
ids = np.zeros(n, dtype=int)
|
||||||
|
|
||||||
|
cur = rng.integers(0, p.disc_vocab)
|
||||||
|
for i in range(n):
|
||||||
|
# occasional change
|
||||||
|
if i == 0 or rng.random() < 0.28:
|
||||||
|
cur = rng.integers(0, p.disc_vocab)
|
||||||
|
ids[i] = cur
|
||||||
|
|
||||||
|
return ids
|
||||||
|
|
||||||
|
|
||||||
|
def _axes_clean(ax: plt.Axes) -> None:
|
||||||
|
"""Keep axes lines optional but remove all text/numbers (diagram-friendly)."""
|
||||||
|
ax.set_xlabel("")
|
||||||
|
ax.set_ylabel("")
|
||||||
|
ax.set_title("")
|
||||||
|
ax.set_xticks([])
|
||||||
|
ax.set_yticks([])
|
||||||
|
ax.tick_params(
|
||||||
|
axis="both",
|
||||||
|
which="both",
|
||||||
|
bottom=False,
|
||||||
|
left=False,
|
||||||
|
top=False,
|
||||||
|
right=False,
|
||||||
|
labelbottom=False,
|
||||||
|
labelleft=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def draw_optionA(out_path: Path, p: Params) -> None:
|
||||||
|
# Figure
|
||||||
|
fig = plt.figure(figsize=(p.width_in, p.height_in), dpi=200)
|
||||||
|
fig.patch.set_alpha(0.0)
|
||||||
|
|
||||||
|
# Two stacked axes (shared x)
|
||||||
|
ax_top = fig.add_axes([0.08, 0.32, 0.90, 0.62])
|
||||||
|
ax_bot = fig.add_axes([0.08, 0.12, 0.90, 0.16], sharex=ax_top)
|
||||||
|
ax_top.patch.set_alpha(0.0)
|
||||||
|
ax_bot.patch.set_alpha(0.0)
|
||||||
|
|
||||||
|
# Generate data
|
||||||
|
t, Y_syn, y_real = make_continuous_curves(p)
|
||||||
|
ids = make_discrete_strip(p)
|
||||||
|
|
||||||
|
x0, x1 = float(t[0]), float(t[-1])
|
||||||
|
span = x1 - x0
|
||||||
|
|
||||||
|
# Optional shaded regime window
|
||||||
|
if p.show_regime_window:
|
||||||
|
rs = x0 + p.regime_start_frac * span
|
||||||
|
re = x0 + p.regime_end_frac * span
|
||||||
|
ax_top.axvspan(rs, re, alpha=0.12) # default color, semi-transparent
|
||||||
|
ax_bot.axvspan(rs, re, alpha=0.12)
|
||||||
|
|
||||||
|
# Optional vertical dashed alignment line
|
||||||
|
if p.show_alignment_line:
|
||||||
|
vx = x0 + p.align_x_frac * span
|
||||||
|
ax_top.axvline(vx, linestyle="--", linewidth=1.2, alpha=0.7)
|
||||||
|
ax_bot.axvline(vx, linestyle="--", linewidth=1.2, alpha=0.7)
|
||||||
|
|
||||||
|
# Continuous curves (use fixed colors for consistency)
|
||||||
|
curve_colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#9467bd"] # blue, orange, green, purple
|
||||||
|
|
||||||
|
# Ghost "real" behind the first curve (faint gray)
|
||||||
|
if y_real is not None:
|
||||||
|
ax_top.plot(t, y_real, linewidth=p.ghost_lw, color="0.65", alpha=0.55, zorder=1)
|
||||||
|
|
||||||
|
for i in range(Y_syn.shape[0]):
|
||||||
|
ax_top.plot(
|
||||||
|
t, Y_syn[i],
|
||||||
|
linewidth=p.curve_lw,
|
||||||
|
color=curve_colors[i % len(curve_colors)],
|
||||||
|
zorder=2
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set top y-limits with padding
|
||||||
|
ymin, ymax = float(Y_syn.min()), float(Y_syn.max())
|
||||||
|
ypad = 0.10 * (ymax - ymin + 1e-9)
|
||||||
|
ax_top.set_xlim(x0, x1)
|
||||||
|
ax_top.set_ylim(ymin - ypad, ymax + ypad)
|
||||||
|
|
||||||
|
# Discrete strip as colored blocks
|
||||||
|
palette = [
|
||||||
|
"#e41a1c", "#377eb8", "#4daf4a", "#984ea3",
|
||||||
|
"#ff7f00", "#ffff33", "#a65628", "#f781bf",
|
||||||
|
]
|
||||||
|
|
||||||
|
n = len(ids)
|
||||||
|
bin_w = span / n
|
||||||
|
gap = p.strip_gap_frac * bin_w
|
||||||
|
ax_bot.set_ylim(0, 1)
|
||||||
|
|
||||||
|
y = (1 - p.strip_height) / 2
|
||||||
|
for i, cat in enumerate(ids):
|
||||||
|
left = x0 + i * bin_w + gap / 2
|
||||||
|
width = bin_w - gap
|
||||||
|
ax_bot.add_patch(
|
||||||
|
Rectangle(
|
||||||
|
(left, y), width, p.strip_height,
|
||||||
|
facecolor=palette[int(cat) % len(palette)],
|
||||||
|
edgecolor="none",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean axes: no ticks/labels; keep spines (axes lines) visible
|
||||||
|
_axes_clean(ax_top)
|
||||||
|
_axes_clean(ax_bot)
|
||||||
|
for ax in (ax_top, ax_bot):
|
||||||
|
for side in ("left", "bottom", "top", "right"):
|
||||||
|
ax.spines[side].set_visible(True)
|
||||||
|
|
||||||
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fig.savefig(out_path, format="svg", transparent=True, bbox_inches="tight", pad_inches=0.0)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
ap = argparse.ArgumentParser()
|
||||||
|
ap.add_argument("--out", type=Path, default=Path("synth_ics_optionA.svg"))
|
||||||
|
ap.add_argument("--seed", type=int, default=7)
|
||||||
|
ap.add_argument("--seconds", type=float, default=10.0)
|
||||||
|
ap.add_argument("--fs", type=int, default=300)
|
||||||
|
ap.add_argument("--curves", type=int, default=3)
|
||||||
|
ap.add_argument("--bins", type=int, default=40)
|
||||||
|
ap.add_argument("--vocab", type=int, default=8)
|
||||||
|
|
||||||
|
ap.add_argument("--no-align", action="store_true")
|
||||||
|
ap.add_argument("--no-regime", action="store_true")
|
||||||
|
ap.add_argument("--no-ghost", action="store_true")
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
p = Params(
|
||||||
|
seed=args.seed,
|
||||||
|
seconds=args.seconds,
|
||||||
|
fs=args.fs,
|
||||||
|
n_curves=args.curves,
|
||||||
|
n_bins=args.bins,
|
||||||
|
disc_vocab=args.vocab,
|
||||||
|
show_alignment_line=not args.no_align,
|
||||||
|
show_regime_window=not args.no_regime,
|
||||||
|
show_real_ghost=not args.no_ghost,
|
||||||
|
)
|
||||||
|
|
||||||
|
draw_optionA(args.out, p)
|
||||||
|
print(f"Wrote: {args.out}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
318
arxiv-style/fig-scripts/draw_synthetic_ics_optionB.py
Normal file
318
arxiv-style/fig-scripts/draw_synthetic_ics_optionB.py
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Option B: "Synthetic ICS Data" as a mini process-story strip (high-level features)
|
||||||
|
- ONE SVG, transparent background
|
||||||
|
- Two frames by default: "steady/normal" -> "disturbance/recovery"
|
||||||
|
- Each frame contains:
|
||||||
|
- Top: multiple continuous feature curves
|
||||||
|
- Bottom: discrete/categorical strip (colored blocks)
|
||||||
|
- A vertical dashed alignment line crossing both
|
||||||
|
- Optional shaded regime window
|
||||||
|
- A right-pointing arrow between frames
|
||||||
|
|
||||||
|
No text, no numbers (axes lines only). Good for draw.io embedding.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
uv run python draw_synthetic_ics_optionB.py --out ./assets/synth_ics_optionB.svg
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.patches import Rectangle, FancyArrowPatch
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Params:
|
||||||
|
seed: int = 7
|
||||||
|
seconds: float = 8.0
|
||||||
|
fs: int = 250
|
||||||
|
|
||||||
|
# Two-frame story
|
||||||
|
n_frames: int = 2
|
||||||
|
|
||||||
|
# Per-frame visuals
|
||||||
|
n_curves: int = 3
|
||||||
|
n_bins: int = 32
|
||||||
|
disc_vocab: int = 8
|
||||||
|
|
||||||
|
# Layout
|
||||||
|
width_in: float = 8.2
|
||||||
|
height_in: float = 2.4
|
||||||
|
# Relative layout inside the figure
|
||||||
|
margin_left: float = 0.05
|
||||||
|
margin_right: float = 0.05
|
||||||
|
margin_bottom: float = 0.12
|
||||||
|
margin_top: float = 0.10
|
||||||
|
frame_gap: float = 0.08 # gap (figure fraction) between frames (space for arrow)
|
||||||
|
|
||||||
|
# Styling
|
||||||
|
curve_lw: float = 2.1
|
||||||
|
ghost_lw: float = 1.8
|
||||||
|
strip_height: float = 0.65
|
||||||
|
strip_gap_frac: float = 0.12
|
||||||
|
|
||||||
|
# Cues
|
||||||
|
show_alignment_line: bool = True
|
||||||
|
align_x_frac: float = 0.60
|
||||||
|
show_regime_window: bool = True
|
||||||
|
regime_start_frac: float = 0.30
|
||||||
|
regime_end_frac: float = 0.46
|
||||||
|
show_real_ghost: bool = False # keep default off for cleaner story
|
||||||
|
show_axes_spines: bool = True # axes lines only (no ticks/labels)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- helpers ----------
|
||||||
|
|
||||||
|
def _smooth(x: np.ndarray, win: int) -> np.ndarray:
|
||||||
|
win = max(3, int(win) | 1)
|
||||||
|
k = np.ones(win, dtype=float)
|
||||||
|
k /= k.sum()
|
||||||
|
return np.convolve(x, k, mode="same")
|
||||||
|
|
||||||
|
|
||||||
|
def _axes_only(ax: plt.Axes, *, keep_spines: bool) -> None:
|
||||||
|
ax.set_xlabel("")
|
||||||
|
ax.set_ylabel("")
|
||||||
|
ax.set_title("")
|
||||||
|
ax.set_xticks([])
|
||||||
|
ax.set_yticks([])
|
||||||
|
ax.tick_params(
|
||||||
|
axis="both",
|
||||||
|
which="both",
|
||||||
|
bottom=False,
|
||||||
|
left=False,
|
||||||
|
top=False,
|
||||||
|
right=False,
|
||||||
|
labelbottom=False,
|
||||||
|
labelleft=False,
|
||||||
|
)
|
||||||
|
ax.grid(False)
|
||||||
|
if keep_spines:
|
||||||
|
for s in ("left", "right", "top", "bottom"):
|
||||||
|
ax.spines[s].set_visible(True)
|
||||||
|
else:
|
||||||
|
for s in ("left", "right", "top", "bottom"):
|
||||||
|
ax.spines[s].set_visible(False)
|
||||||
|
|
||||||
|
|
||||||
|
def make_frame_continuous(seed: int, seconds: float, fs: int, n_curves: int, style: str) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
style:
|
||||||
|
- "steady": smoother, smaller bumps
|
||||||
|
- "disturb": larger bumps and more variance
|
||||||
|
"""
|
||||||
|
rng = np.random.default_rng(seed)
|
||||||
|
T = int(seconds * fs)
|
||||||
|
t = np.linspace(0, seconds, T, endpoint=False)
|
||||||
|
|
||||||
|
amp_bump = 0.40 if style == "steady" else 0.85
|
||||||
|
amp_noise = 0.09 if style == "steady" else 0.14
|
||||||
|
amp_scale = 0.38 if style == "steady" else 0.46
|
||||||
|
|
||||||
|
base_freqs = [0.10, 0.08, 0.12, 0.09]
|
||||||
|
mid_freqs = [0.65, 0.78, 0.90, 0.72]
|
||||||
|
|
||||||
|
Y = []
|
||||||
|
for i in range(n_curves):
|
||||||
|
f_slow = base_freqs[i % len(base_freqs)]
|
||||||
|
f_mid = mid_freqs[i % len(mid_freqs)]
|
||||||
|
ph = rng.uniform(0, 2 * np.pi)
|
||||||
|
|
||||||
|
y = (
|
||||||
|
0.95 * np.sin(2 * np.pi * f_slow * t + ph)
|
||||||
|
+ 0.28 * np.sin(2 * np.pi * f_mid * t + 0.65 * ph)
|
||||||
|
)
|
||||||
|
|
||||||
|
bumps = np.zeros_like(t)
|
||||||
|
n_bumps = 2 if style == "steady" else 3
|
||||||
|
for _ in range(n_bumps):
|
||||||
|
mu = rng.uniform(0.9, seconds - 0.9)
|
||||||
|
sig = rng.uniform(0.35, 0.75) if style == "steady" else rng.uniform(0.20, 0.55)
|
||||||
|
bumps += np.exp(-0.5 * ((t - mu) / (sig + 1e-9)) ** 2)
|
||||||
|
y += amp_bump * bumps
|
||||||
|
|
||||||
|
noise = _smooth(rng.normal(0, 1, size=T), win=int(fs * 0.04))
|
||||||
|
y += amp_noise * noise
|
||||||
|
|
||||||
|
y = (y - y.mean()) / (y.std() + 1e-9)
|
||||||
|
y *= amp_scale
|
||||||
|
Y.append(y)
|
||||||
|
|
||||||
|
return t, np.vstack(Y)
|
||||||
|
|
||||||
|
|
||||||
|
def make_frame_discrete(seed: int, n_bins: int, vocab: int, style: str) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
style:
|
||||||
|
- "steady": fewer transitions
|
||||||
|
- "disturb": more transitions
|
||||||
|
"""
|
||||||
|
rng = np.random.default_rng(seed + 111)
|
||||||
|
ids = np.zeros(n_bins, dtype=int)
|
||||||
|
|
||||||
|
p_change = 0.20 if style == "steady" else 0.38
|
||||||
|
cur = rng.integers(0, vocab)
|
||||||
|
for i in range(n_bins):
|
||||||
|
if i == 0 or rng.random() < p_change:
|
||||||
|
cur = rng.integers(0, vocab)
|
||||||
|
ids[i] = cur
|
||||||
|
return ids
|
||||||
|
|
||||||
|
|
||||||
|
def draw_frame(ax_top: plt.Axes, ax_bot: plt.Axes, t: np.ndarray, Y: np.ndarray, ids: np.ndarray, p: Params) -> None:
|
||||||
|
# Optional cues
|
||||||
|
x0, x1 = float(t[0]), float(t[-1])
|
||||||
|
span = x1 - x0
|
||||||
|
|
||||||
|
if p.show_regime_window:
|
||||||
|
rs = x0 + p.regime_start_frac * span
|
||||||
|
re = x0 + p.regime_end_frac * span
|
||||||
|
ax_top.axvspan(rs, re, alpha=0.12) # default color
|
||||||
|
ax_bot.axvspan(rs, re, alpha=0.12)
|
||||||
|
|
||||||
|
if p.show_alignment_line:
|
||||||
|
vx = x0 + p.align_x_frac * span
|
||||||
|
ax_top.axvline(vx, linestyle="--", linewidth=1.15, alpha=0.7)
|
||||||
|
ax_bot.axvline(vx, linestyle="--", linewidth=1.15, alpha=0.7)
|
||||||
|
|
||||||
|
# Curves
|
||||||
|
curve_colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#9467bd"]
|
||||||
|
for i in range(Y.shape[0]):
|
||||||
|
ax_top.plot(t, Y[i], linewidth=p.curve_lw, color=curve_colors[i % len(curve_colors)])
|
||||||
|
|
||||||
|
ymin, ymax = float(Y.min()), float(Y.max())
|
||||||
|
ypad = 0.10 * (ymax - ymin + 1e-9)
|
||||||
|
ax_top.set_xlim(x0, x1)
|
||||||
|
ax_top.set_ylim(ymin - ypad, ymax + ypad)
|
||||||
|
|
||||||
|
# Discrete strip
|
||||||
|
palette = [
|
||||||
|
"#e41a1c", "#377eb8", "#4daf4a", "#984ea3",
|
||||||
|
"#ff7f00", "#ffff33", "#a65628", "#f781bf",
|
||||||
|
]
|
||||||
|
|
||||||
|
ax_bot.set_xlim(x0, x1)
|
||||||
|
ax_bot.set_ylim(0, 1)
|
||||||
|
|
||||||
|
n = len(ids)
|
||||||
|
bin_w = span / n
|
||||||
|
gap = p.strip_gap_frac * bin_w
|
||||||
|
y = (1 - p.strip_height) / 2
|
||||||
|
|
||||||
|
for i, cat in enumerate(ids):
|
||||||
|
left = x0 + i * bin_w + gap / 2
|
||||||
|
width = bin_w - gap
|
||||||
|
ax_bot.add_patch(
|
||||||
|
Rectangle((left, y), width, p.strip_height, facecolor=palette[int(cat) % len(palette)], edgecolor="none")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Axes-only style
|
||||||
|
_axes_only(ax_top, keep_spines=p.show_axes_spines)
|
||||||
|
_axes_only(ax_bot, keep_spines=p.show_axes_spines)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- main drawing ----------
|
||||||
|
|
||||||
|
def draw_optionB(out_path: Path, p: Params) -> None:
|
||||||
|
fig = plt.figure(figsize=(p.width_in, p.height_in), dpi=200)
|
||||||
|
fig.patch.set_alpha(0.0)
|
||||||
|
|
||||||
|
# Compute frame layout in figure coordinates
|
||||||
|
# Each frame has two stacked axes: top curves and bottom strip.
|
||||||
|
usable_w = 1.0 - p.margin_left - p.margin_right
|
||||||
|
usable_h = 1.0 - p.margin_bottom - p.margin_top
|
||||||
|
|
||||||
|
# Leave gap between frames for arrow
|
||||||
|
total_gap = p.frame_gap * (p.n_frames - 1)
|
||||||
|
frame_w = (usable_w - total_gap) / p.n_frames
|
||||||
|
|
||||||
|
# Within each frame: vertical split
|
||||||
|
top_h = usable_h * 0.70
|
||||||
|
bot_h = usable_h * 0.18
|
||||||
|
v_gap = usable_h * 0.06
|
||||||
|
# bottoms
|
||||||
|
bot_y = p.margin_bottom
|
||||||
|
top_y = bot_y + bot_h + v_gap
|
||||||
|
|
||||||
|
axes_pairs = []
|
||||||
|
for f in range(p.n_frames):
|
||||||
|
left = p.margin_left + f * (frame_w + p.frame_gap)
|
||||||
|
ax_top = fig.add_axes([left, top_y, frame_w, top_h])
|
||||||
|
ax_bot = fig.add_axes([left, bot_y, frame_w, bot_h], sharex=ax_top)
|
||||||
|
ax_top.patch.set_alpha(0.0)
|
||||||
|
ax_bot.patch.set_alpha(0.0)
|
||||||
|
axes_pairs.append((ax_top, ax_bot))
|
||||||
|
|
||||||
|
# Data per frame
|
||||||
|
styles = ["steady", "disturb"] if p.n_frames == 2 else ["steady"] * (p.n_frames - 1) + ["disturb"]
|
||||||
|
for idx, ((ax_top, ax_bot), style) in enumerate(zip(axes_pairs, styles)):
|
||||||
|
t, Y = make_frame_continuous(p.seed + 10 * idx, p.seconds, p.fs, p.n_curves, style=style)
|
||||||
|
ids = make_frame_discrete(p.seed + 10 * idx, p.n_bins, p.disc_vocab, style=style)
|
||||||
|
draw_frame(ax_top, ax_bot, t, Y, ids, p)
|
||||||
|
|
||||||
|
# Add a visual arrow between frames (in figure coordinates)
|
||||||
|
if p.n_frames >= 2:
|
||||||
|
for f in range(p.n_frames - 1):
|
||||||
|
# center between frame f and f+1
|
||||||
|
x_left = p.margin_left + f * (frame_w + p.frame_gap) + frame_w
|
||||||
|
x_right = p.margin_left + (f + 1) * (frame_w + p.frame_gap)
|
||||||
|
x_mid = (x_left + x_right) / 2
|
||||||
|
# arrow y in the middle of the frame stack
|
||||||
|
y_mid = bot_y + (bot_h + v_gap + top_h) / 2
|
||||||
|
|
||||||
|
arr = FancyArrowPatch(
|
||||||
|
(x_mid - 0.015, y_mid),
|
||||||
|
(x_mid + 0.015, y_mid),
|
||||||
|
transform=fig.transFigure,
|
||||||
|
arrowstyle="-|>",
|
||||||
|
mutation_scale=18,
|
||||||
|
linewidth=1.6,
|
||||||
|
color="black",
|
||||||
|
)
|
||||||
|
fig.patches.append(arr)
|
||||||
|
|
||||||
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fig.savefig(out_path, format="svg", transparent=True, bbox_inches="tight", pad_inches=0.0)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
ap = argparse.ArgumentParser()
|
||||||
|
ap.add_argument("--out", type=Path, default=Path("synth_ics_optionB.svg"))
|
||||||
|
ap.add_argument("--seed", type=int, default=7)
|
||||||
|
ap.add_argument("--seconds", type=float, default=8.0)
|
||||||
|
ap.add_argument("--fs", type=int, default=250)
|
||||||
|
ap.add_argument("--frames", type=int, default=2, choices=[2, 3], help="2 or 3 frames (story strip)")
|
||||||
|
ap.add_argument("--curves", type=int, default=3)
|
||||||
|
ap.add_argument("--bins", type=int, default=32)
|
||||||
|
ap.add_argument("--vocab", type=int, default=8)
|
||||||
|
ap.add_argument("--no-align", action="store_true")
|
||||||
|
ap.add_argument("--no-regime", action="store_true")
|
||||||
|
ap.add_argument("--no-spines", action="store_true")
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
p = Params(
|
||||||
|
seed=args.seed,
|
||||||
|
seconds=args.seconds,
|
||||||
|
fs=args.fs,
|
||||||
|
n_frames=args.frames,
|
||||||
|
n_curves=args.curves,
|
||||||
|
n_bins=args.bins,
|
||||||
|
disc_vocab=args.vocab,
|
||||||
|
show_alignment_line=not args.no_align,
|
||||||
|
show_regime_window=not args.no_regime,
|
||||||
|
show_axes_spines=not args.no_spines,
|
||||||
|
)
|
||||||
|
|
||||||
|
draw_optionB(args.out, p)
|
||||||
|
print(f"Wrote: {args.out}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
201
arxiv-style/fig-scripts/draw_transformer_lower_half.py
Normal file
201
arxiv-style/fig-scripts/draw_transformer_lower_half.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Draw the *Transformer section* lower-half visuals:
|
||||||
|
- Continuous channels: multiple smooth curves (like the colored trend lines)
|
||||||
|
- Discrete channels: small colored bars/ticks along the bottom
|
||||||
|
|
||||||
|
Output: ONE SVG with transparent background, axes hidden.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
uv run python draw_transformer_lower_half.py --out ./assets/transformer_lower_half.svg
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.patches import Rectangle
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Params:
|
||||||
|
seed: int = 7
|
||||||
|
seconds: float = 10.0
|
||||||
|
fs: int = 300
|
||||||
|
|
||||||
|
# Continuous channels
|
||||||
|
n_curves: int = 3
|
||||||
|
curve_lw: float = 2.4
|
||||||
|
|
||||||
|
# Discrete bars
|
||||||
|
n_bins: int = 40 # number of discrete bars/ticks across time
|
||||||
|
bar_height: float = 0.11 # relative height inside bar strip axis
|
||||||
|
bar_gap: float = 0.08 # gap between bars (fraction of bar width)
|
||||||
|
|
||||||
|
# Canvas sizing
|
||||||
|
width_in: float = 5.8
|
||||||
|
height_in: float = 1.9
|
||||||
|
|
||||||
|
|
||||||
|
def _smooth(x: np.ndarray, win: int) -> np.ndarray:
|
||||||
|
win = max(3, int(win) | 1) # odd
|
||||||
|
k = np.ones(win, dtype=float)
|
||||||
|
k /= k.sum()
|
||||||
|
return np.convolve(x, k, mode="same")
|
||||||
|
|
||||||
|
|
||||||
|
def make_continuous_curves(p: Params) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Produce 3 smooth curves with gentle long-term temporal patterning.
|
||||||
|
Returns:
|
||||||
|
t: (T,)
|
||||||
|
Y: (n_curves, T)
|
||||||
|
"""
|
||||||
|
rng = np.random.default_rng(p.seed)
|
||||||
|
T = int(p.seconds * p.fs)
|
||||||
|
t = np.linspace(0, p.seconds, T, endpoint=False)
|
||||||
|
|
||||||
|
Y = []
|
||||||
|
base_freqs = [0.12, 0.09, 0.15]
|
||||||
|
mid_freqs = [0.65, 0.85, 0.75]
|
||||||
|
|
||||||
|
for i in range(p.n_curves):
|
||||||
|
f1 = base_freqs[i % len(base_freqs)]
|
||||||
|
f2 = mid_freqs[i % len(mid_freqs)]
|
||||||
|
ph = rng.uniform(0, 2 * np.pi)
|
||||||
|
|
||||||
|
# Smooth trend + mid wiggle
|
||||||
|
y = (
|
||||||
|
1.00 * np.sin(2 * np.pi * f1 * t + ph)
|
||||||
|
+ 0.35 * np.sin(2 * np.pi * f2 * t + 0.7 * ph)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add a couple of smooth bumps (like slow pattern changes)
|
||||||
|
bumps = np.zeros_like(t)
|
||||||
|
for _ in range(2):
|
||||||
|
mu = rng.uniform(0.8, p.seconds - 0.8)
|
||||||
|
sig = rng.uniform(0.35, 0.75)
|
||||||
|
bumps += np.exp(-0.5 * ((t - mu) / sig) ** 2)
|
||||||
|
y += 0.55 * bumps
|
||||||
|
|
||||||
|
# Mild smooth noise
|
||||||
|
noise = _smooth(rng.normal(0, 1, size=T), win=int(p.fs * 0.04))
|
||||||
|
y += 0.12 * noise
|
||||||
|
|
||||||
|
# Normalize and compress amplitude to fit nicely
|
||||||
|
y = (y - y.mean()) / (y.std() + 1e-9)
|
||||||
|
y *= 0.42
|
||||||
|
|
||||||
|
Y.append(y)
|
||||||
|
|
||||||
|
return t, np.vstack(Y)
|
||||||
|
|
||||||
|
|
||||||
|
def make_discrete_bars(p: Params) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Generate discrete "token-like" bars across time bins.
|
||||||
|
Returns:
|
||||||
|
ids: (n_bins,) integer category ids
|
||||||
|
"""
|
||||||
|
rng = np.random.default_rng(p.seed + 123)
|
||||||
|
n = p.n_bins
|
||||||
|
|
||||||
|
# A piecewise-constant sequence with occasional changes (looks like discrete channel)
|
||||||
|
ids = np.zeros(n, dtype=int)
|
||||||
|
cur = rng.integers(0, 8)
|
||||||
|
for i in range(n):
|
||||||
|
if i == 0 or rng.random() < 0.25:
|
||||||
|
cur = rng.integers(0, 8)
|
||||||
|
ids[i] = cur
|
||||||
|
return ids
|
||||||
|
|
||||||
|
|
||||||
|
def draw_transformer_lower_half_svg(out_path: Path, p: Params) -> None:
|
||||||
|
# --- Figure + transparent background ---
|
||||||
|
fig = plt.figure(figsize=(p.width_in, p.height_in), dpi=200)
|
||||||
|
fig.patch.set_alpha(0.0)
|
||||||
|
|
||||||
|
# Two stacked axes: curves (top), bars (bottom)
|
||||||
|
# Tight, diagram-style layout
|
||||||
|
ax_curves = fig.add_axes([0.06, 0.28, 0.90, 0.68]) # [left, bottom, width, height]
|
||||||
|
ax_bars = fig.add_axes([0.06, 0.10, 0.90, 0.14])
|
||||||
|
|
||||||
|
ax_curves.patch.set_alpha(0.0)
|
||||||
|
ax_bars.patch.set_alpha(0.0)
|
||||||
|
|
||||||
|
for ax in (ax_curves, ax_bars):
|
||||||
|
ax.set_axis_off()
|
||||||
|
|
||||||
|
# --- Data ---
|
||||||
|
t, Y = make_continuous_curves(p)
|
||||||
|
ids = make_discrete_bars(p)
|
||||||
|
|
||||||
|
# --- Continuous curves (explicit colors to match the “multi-colored” look) ---
|
||||||
|
# Feel free to swap these hex colors to match your figure theme.
|
||||||
|
curve_colors = ["#1f77b4", "#ff7f0e", "#2ca02c"] # blue / orange / green
|
||||||
|
|
||||||
|
for i in range(Y.shape[0]):
|
||||||
|
ax_curves.plot(t, Y[i], linewidth=p.curve_lw, color=curve_colors[i % len(curve_colors)])
|
||||||
|
|
||||||
|
# Set curve bounds with padding (keeps it clean)
|
||||||
|
ymin, ymax = float(Y.min()), float(Y.max())
|
||||||
|
pad = 0.10 * (ymax - ymin + 1e-9)
|
||||||
|
ax_curves.set_xlim(t[0], t[-1])
|
||||||
|
ax_curves.set_ylim(ymin - pad, ymax + pad)
|
||||||
|
|
||||||
|
# --- Discrete bars: small colored rectangles along the timeline ---
|
||||||
|
# A small palette for categories (repeats if more categories appear)
|
||||||
|
bar_palette = [
|
||||||
|
"#e41a1c", "#377eb8", "#4daf4a", "#984ea3",
|
||||||
|
"#ff7f00", "#ffff33", "#a65628", "#f781bf",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Convert bins into time spans
|
||||||
|
n = len(ids)
|
||||||
|
x0, x1 = t[0], t[-1]
|
||||||
|
total = x1 - x0
|
||||||
|
bin_w = total / n
|
||||||
|
gap = p.bar_gap * bin_w
|
||||||
|
|
||||||
|
# Draw bars in [0,1] y-space inside ax_bars
|
||||||
|
ax_bars.set_xlim(x0, x1)
|
||||||
|
ax_bars.set_ylim(0, 1)
|
||||||
|
|
||||||
|
for i, cat in enumerate(ids):
|
||||||
|
left = x0 + i * bin_w + gap / 2
|
||||||
|
width = bin_w - gap
|
||||||
|
color = bar_palette[int(cat) % len(bar_palette)]
|
||||||
|
rect = Rectangle(
|
||||||
|
(left, (1 - p.bar_height) / 2),
|
||||||
|
width,
|
||||||
|
p.bar_height,
|
||||||
|
facecolor=color,
|
||||||
|
edgecolor="none",
|
||||||
|
)
|
||||||
|
ax_bars.add_patch(rect)
|
||||||
|
|
||||||
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fig.savefig(out_path, format="svg", transparent=True, bbox_inches="tight", pad_inches=0.0)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
ap = argparse.ArgumentParser()
|
||||||
|
ap.add_argument("--out", type=Path, default=Path("transformer_lower_half.svg"))
|
||||||
|
ap.add_argument("--seed", type=int, default=7)
|
||||||
|
ap.add_argument("--seconds", type=float, default=10.0)
|
||||||
|
ap.add_argument("--fs", type=int, default=300)
|
||||||
|
ap.add_argument("--bins", type=int, default=40)
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
p = Params(seed=args.seed, seconds=args.seconds, fs=args.fs, n_bins=args.bins)
|
||||||
|
draw_transformer_lower_half_svg(args.out, p)
|
||||||
|
print(f"Wrote: {args.out}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
202
arxiv-style/fig-scripts/draw_transformer_lower_half_axes.py
Normal file
202
arxiv-style/fig-scripts/draw_transformer_lower_half_axes.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Transformer section lower-half visuals WITH AXES ONLY:
|
||||||
|
- Axes spines visible
|
||||||
|
- NO numbers (tick labels hidden)
|
||||||
|
- NO words (axis labels removed)
|
||||||
|
- Transparent background
|
||||||
|
- One SVG output
|
||||||
|
|
||||||
|
Run:
|
||||||
|
uv run python draw_transformer_lower_half_axes_only.py --out ./assets/transformer_lower_half_axes_only.svg
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.patches import Rectangle
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Params:
|
||||||
|
seed: int = 7
|
||||||
|
seconds: float = 10.0
|
||||||
|
fs: int = 300
|
||||||
|
|
||||||
|
# Continuous channels
|
||||||
|
n_curves: int = 3
|
||||||
|
curve_lw: float = 2.4
|
||||||
|
|
||||||
|
# Discrete bars
|
||||||
|
n_bins: int = 40
|
||||||
|
bar_height: float = 0.55 # fraction of the discrete-axis y-range
|
||||||
|
bar_gap: float = 0.08 # fraction of bar width
|
||||||
|
|
||||||
|
# Figure size
|
||||||
|
width_in: float = 6.6
|
||||||
|
height_in: float = 2.6
|
||||||
|
|
||||||
|
|
||||||
|
def _smooth(x: np.ndarray, win: int) -> np.ndarray:
|
||||||
|
win = max(3, int(win) | 1) # odd
|
||||||
|
k = np.ones(win, dtype=float)
|
||||||
|
k /= k.sum()
|
||||||
|
return np.convolve(x, k, mode="same")
|
||||||
|
|
||||||
|
|
||||||
|
def make_continuous_curves(p: Params) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
rng = np.random.default_rng(p.seed)
|
||||||
|
T = int(p.seconds * p.fs)
|
||||||
|
t = np.linspace(0, p.seconds, T, endpoint=False)
|
||||||
|
|
||||||
|
Y = []
|
||||||
|
base_freqs = [0.12, 0.09, 0.15]
|
||||||
|
mid_freqs = [0.65, 0.85, 0.75]
|
||||||
|
|
||||||
|
for i in range(p.n_curves):
|
||||||
|
f1 = base_freqs[i % len(base_freqs)]
|
||||||
|
f2 = mid_freqs[i % len(mid_freqs)]
|
||||||
|
ph = rng.uniform(0, 2 * np.pi)
|
||||||
|
|
||||||
|
y = (
|
||||||
|
1.00 * np.sin(2 * np.pi * f1 * t + ph)
|
||||||
|
+ 0.35 * np.sin(2 * np.pi * f2 * t + 0.7 * ph)
|
||||||
|
)
|
||||||
|
|
||||||
|
bumps = np.zeros_like(t)
|
||||||
|
for _ in range(2):
|
||||||
|
mu = rng.uniform(0.8, p.seconds - 0.8)
|
||||||
|
sig = rng.uniform(0.35, 0.75)
|
||||||
|
bumps += np.exp(-0.5 * ((t - mu) / sig) ** 2)
|
||||||
|
y += 0.55 * bumps
|
||||||
|
|
||||||
|
noise = _smooth(rng.normal(0, 1, size=T), win=int(p.fs * 0.04))
|
||||||
|
y += 0.12 * noise
|
||||||
|
|
||||||
|
y = (y - y.mean()) / (y.std() + 1e-9)
|
||||||
|
y *= 0.42
|
||||||
|
Y.append(y)
|
||||||
|
|
||||||
|
return t, np.vstack(Y)
|
||||||
|
|
||||||
|
|
||||||
|
def make_discrete_bars(p: Params) -> np.ndarray:
|
||||||
|
rng = np.random.default_rng(p.seed + 123)
|
||||||
|
n = p.n_bins
|
||||||
|
|
||||||
|
ids = np.zeros(n, dtype=int)
|
||||||
|
cur = rng.integers(0, 8)
|
||||||
|
for i in range(n):
|
||||||
|
if i == 0 or rng.random() < 0.25:
|
||||||
|
cur = rng.integers(0, 8)
|
||||||
|
ids[i] = cur
|
||||||
|
return ids
|
||||||
|
|
||||||
|
|
||||||
|
def _axes_only(ax: plt.Axes) -> None:
|
||||||
|
"""Keep spines (axes lines), remove all ticks/labels/words."""
|
||||||
|
# No labels
|
||||||
|
ax.set_xlabel("")
|
||||||
|
ax.set_ylabel("")
|
||||||
|
ax.set_title("")
|
||||||
|
|
||||||
|
# Keep spines as the only axes element
|
||||||
|
for side in ("top", "right", "bottom", "left"):
|
||||||
|
ax.spines[side].set_visible(True)
|
||||||
|
|
||||||
|
# Remove tick marks and tick labels entirely
|
||||||
|
ax.set_xticks([])
|
||||||
|
ax.set_yticks([])
|
||||||
|
ax.tick_params(
|
||||||
|
axis="both",
|
||||||
|
which="both",
|
||||||
|
bottom=False,
|
||||||
|
left=False,
|
||||||
|
top=False,
|
||||||
|
right=False,
|
||||||
|
labelbottom=False,
|
||||||
|
labelleft=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# No grid
|
||||||
|
ax.grid(False)
|
||||||
|
|
||||||
|
|
||||||
|
def draw_transformer_lower_half_svg(out_path: Path, p: Params) -> None:
|
||||||
|
fig = plt.figure(figsize=(p.width_in, p.height_in), dpi=200)
|
||||||
|
fig.patch.set_alpha(0.0)
|
||||||
|
|
||||||
|
# Two axes sharing x (top curves, bottom bars)
|
||||||
|
ax_curves = fig.add_axes([0.10, 0.38, 0.86, 0.56])
|
||||||
|
ax_bars = fig.add_axes([0.10, 0.14, 0.86, 0.18], sharex=ax_curves)
|
||||||
|
|
||||||
|
ax_curves.patch.set_alpha(0.0)
|
||||||
|
ax_bars.patch.set_alpha(0.0)
|
||||||
|
|
||||||
|
# Data
|
||||||
|
t, Y = make_continuous_curves(p)
|
||||||
|
ids = make_discrete_bars(p)
|
||||||
|
|
||||||
|
# Top: continuous curves
|
||||||
|
curve_colors = ["#1f77b4", "#ff7f0e", "#2ca02c"] # blue / orange / green
|
||||||
|
for i in range(Y.shape[0]):
|
||||||
|
ax_curves.plot(t, Y[i], linewidth=p.curve_lw, color=curve_colors[i % len(curve_colors)])
|
||||||
|
|
||||||
|
ymin, ymax = float(Y.min()), float(Y.max())
|
||||||
|
ypad = 0.10 * (ymax - ymin + 1e-9)
|
||||||
|
ax_curves.set_xlim(t[0], t[-1])
|
||||||
|
ax_curves.set_ylim(ymin - ypad, ymax + ypad)
|
||||||
|
|
||||||
|
# Bottom: discrete bars (colored strip)
|
||||||
|
bar_palette = [
|
||||||
|
"#e41a1c", "#377eb8", "#4daf4a", "#984ea3",
|
||||||
|
"#ff7f00", "#ffff33", "#a65628", "#f781bf",
|
||||||
|
]
|
||||||
|
|
||||||
|
x0, x1 = t[0], t[-1]
|
||||||
|
total = x1 - x0
|
||||||
|
n = len(ids)
|
||||||
|
bin_w = total / n
|
||||||
|
gap = p.bar_gap * bin_w
|
||||||
|
|
||||||
|
ax_bars.set_xlim(x0, x1)
|
||||||
|
ax_bars.set_ylim(0, 1)
|
||||||
|
|
||||||
|
bar_y = (1 - p.bar_height) / 2
|
||||||
|
for i, cat in enumerate(ids):
|
||||||
|
left = x0 + i * bin_w + gap / 2
|
||||||
|
width = bin_w - gap
|
||||||
|
color = bar_palette[int(cat) % len(bar_palette)]
|
||||||
|
ax_bars.add_patch(Rectangle((left, bar_y), width, p.bar_height, facecolor=color, edgecolor="none"))
|
||||||
|
|
||||||
|
# Apply "axes only" styling (no numbers/words)
|
||||||
|
_axes_only(ax_curves)
|
||||||
|
_axes_only(ax_bars)
|
||||||
|
|
||||||
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fig.savefig(out_path, format="svg", transparent=True, bbox_inches="tight", pad_inches=0.0)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
ap = argparse.ArgumentParser()
|
||||||
|
ap.add_argument("--out", type=Path, default=Path("transformer_lower_half_axes_only.svg"))
|
||||||
|
ap.add_argument("--seed", type=int, default=7)
|
||||||
|
ap.add_argument("--seconds", type=float, default=10.0)
|
||||||
|
ap.add_argument("--fs", type=int, default=300)
|
||||||
|
ap.add_argument("--bins", type=int, default=40)
|
||||||
|
ap.add_argument("--curves", type=int, default=3)
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
p = Params(seed=args.seed, seconds=args.seconds, fs=args.fs, n_bins=args.bins, n_curves=args.curves)
|
||||||
|
draw_transformer_lower_half_svg(args.out, p)
|
||||||
|
print(f"Wrote: {args.out}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
161
arxiv-style/fig-scripts/gen_noise_ddmp.py
Normal file
161
arxiv-style/fig-scripts/gen_noise_ddmp.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Generate "Noisy Residual" and "Denoised Residual" curves as SVGs.
|
||||||
|
|
||||||
|
- Produces TWO separate SVG files:
|
||||||
|
noisy_residual.svg
|
||||||
|
denoised_residual.svg
|
||||||
|
|
||||||
|
- Curves are synthetic but shaped like residual noise + denoised residual.
|
||||||
|
- Uses only matplotlib + numpy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CurveParams:
|
||||||
|
seconds: float = 12.0 # length of the signal
|
||||||
|
fs: int = 250 # samples per second
|
||||||
|
seed: int = 7 # RNG seed for reproducibility
|
||||||
|
base_amp: float = 0.12 # smooth baseline amplitude
|
||||||
|
noise_amp: float = 0.55 # high-frequency noise amplitude
|
||||||
|
burst_amp: float = 1.2 # occasional spike amplitude
|
||||||
|
burst_rate_hz: float = 0.35 # average spike frequency
|
||||||
|
denoise_smooth_ms: float = 120 # smoothing window for "denoised" (ms)
|
||||||
|
|
||||||
|
|
||||||
|
def gaussian_smooth(x: np.ndarray, sigma_samples: float) -> np.ndarray:
|
||||||
|
"""Gaussian smoothing using explicit kernel convolution (no SciPy dependency)."""
|
||||||
|
if sigma_samples <= 0:
|
||||||
|
return x.copy()
|
||||||
|
|
||||||
|
radius = int(np.ceil(4 * sigma_samples))
|
||||||
|
k = np.arange(-radius, radius + 1, dtype=float)
|
||||||
|
kernel = np.exp(-(k**2) / (2 * sigma_samples**2))
|
||||||
|
kernel /= kernel.sum()
|
||||||
|
return np.convolve(x, kernel, mode="same")
|
||||||
|
|
||||||
|
|
||||||
|
def make_residual(params: CurveParams) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Create synthetic residual:
|
||||||
|
- baseline: smooth wavy trend + slight drift
|
||||||
|
- noise: band-limited-ish high-frequency noise
|
||||||
|
- bursts: sparse spikes / impulse-like events
|
||||||
|
Returns: (t, noisy, denoised)
|
||||||
|
"""
|
||||||
|
rng = np.random.default_rng(params.seed)
|
||||||
|
n = int(params.seconds * params.fs)
|
||||||
|
t = np.linspace(0, params.seconds, n, endpoint=False)
|
||||||
|
|
||||||
|
# Smooth baseline (small): combination of sinusoids + small random drift
|
||||||
|
baseline = (
|
||||||
|
0.7 * np.sin(2 * np.pi * 0.35 * t + 0.2)
|
||||||
|
+ 0.35 * np.sin(2 * np.pi * 0.9 * t + 1.2)
|
||||||
|
+ 0.25 * np.sin(2 * np.pi * 0.15 * t + 2.0)
|
||||||
|
)
|
||||||
|
baseline *= params.base_amp
|
||||||
|
drift = np.cumsum(rng.normal(0, 1, size=n))
|
||||||
|
drift = drift / (np.max(np.abs(drift)) + 1e-9) * (params.base_amp * 0.25)
|
||||||
|
baseline = baseline + drift
|
||||||
|
|
||||||
|
# High-frequency noise: whitened then lightly smoothed to look "oscillatory"
|
||||||
|
raw = rng.normal(0, 1, size=n)
|
||||||
|
hf = raw - gaussian_smooth(raw, sigma_samples=params.fs * 0.03) # remove slow part
|
||||||
|
hf = hf / (np.std(hf) + 1e-9)
|
||||||
|
hf *= params.noise_amp
|
||||||
|
|
||||||
|
# Bursts/spikes: Poisson process impulses convolved with short kernel
|
||||||
|
expected_bursts = params.burst_rate_hz * params.seconds
|
||||||
|
k_bursts = rng.poisson(expected_bursts)
|
||||||
|
impulses = np.zeros(n)
|
||||||
|
if k_bursts > 0:
|
||||||
|
idx = rng.integers(0, n, size=k_bursts)
|
||||||
|
impulses[idx] = rng.normal(loc=1.0, scale=0.4, size=k_bursts)
|
||||||
|
|
||||||
|
# Shape impulses into spikes (asymmetric bump)
|
||||||
|
spike_kernel_len = int(params.fs * 0.06) # ~60ms
|
||||||
|
spike_kernel_len = max(spike_kernel_len, 7)
|
||||||
|
spike_t = np.arange(spike_kernel_len)
|
||||||
|
spike_kernel = np.exp(-spike_t / (params.fs * 0.012)) # fast decay
|
||||||
|
spike_kernel *= np.hanning(spike_kernel_len) # taper
|
||||||
|
spike_kernel /= (spike_kernel.max() + 1e-9)
|
||||||
|
|
||||||
|
bursts = np.convolve(impulses, spike_kernel, mode="same")
|
||||||
|
bursts *= params.burst_amp
|
||||||
|
|
||||||
|
noisy = baseline + hf + bursts
|
||||||
|
|
||||||
|
# "Denoised": remove high-frequency using Gaussian smoothing,
|
||||||
|
# but keep spike structures partially.
|
||||||
|
smooth_sigma = (params.denoise_smooth_ms / 1000.0) * params.fs / 3.0
|
||||||
|
denoised = gaussian_smooth(noisy, sigma_samples=smooth_sigma)
|
||||||
|
|
||||||
|
return t, noisy, denoised
|
||||||
|
|
||||||
|
|
||||||
|
def save_curve_svg(
|
||||||
|
t: np.ndarray,
|
||||||
|
y: np.ndarray,
|
||||||
|
out_path: Path,
|
||||||
|
*,
|
||||||
|
width_in: float = 5.4,
|
||||||
|
height_in: float = 1.6,
|
||||||
|
lw: float = 2.2,
|
||||||
|
pad: float = 0.03,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Save a clean, figure-only SVG suitable for embedding in diagrams.
|
||||||
|
- No axes, ticks, labels.
|
||||||
|
- Tight bounding box.
|
||||||
|
"""
|
||||||
|
fig = plt.figure(figsize=(width_in, height_in), dpi=200)
|
||||||
|
ax = fig.add_axes([pad, pad, 1 - 2 * pad, 1 - 2 * pad])
|
||||||
|
|
||||||
|
ax.plot(t, y, linewidth=lw)
|
||||||
|
|
||||||
|
# Make it "icon-like" for diagrams: no axes or frames
|
||||||
|
ax.set_axis_off()
|
||||||
|
|
||||||
|
# Ensure bounds include a little padding
|
||||||
|
ymin, ymax = np.min(y), np.max(y)
|
||||||
|
ypad = 0.08 * (ymax - ymin + 1e-9)
|
||||||
|
ax.set_xlim(t[0], t[-1])
|
||||||
|
ax.set_ylim(ymin - ypad, ymax + ypad)
|
||||||
|
|
||||||
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fig.savefig(out_path, format="svg", bbox_inches="tight", pad_inches=0.0)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
ap = argparse.ArgumentParser()
|
||||||
|
ap.add_argument("--outdir", type=Path, default=Path("."), help="Output directory")
|
||||||
|
ap.add_argument("--seed", type=int, default=7, help="RNG seed")
|
||||||
|
ap.add_argument("--seconds", type=float, default=12.0, help="Signal length (s)")
|
||||||
|
ap.add_argument("--fs", type=int, default=250, help="Sampling rate (Hz)")
|
||||||
|
ap.add_argument("--prefix", type=str, default="", help="Filename prefix (optional)")
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
params = CurveParams(seconds=args.seconds, fs=args.fs, seed=args.seed)
|
||||||
|
t, noisy, denoised = make_residual(params)
|
||||||
|
|
||||||
|
noisy_path = args.outdir / f"{args.prefix}noisy_residual.svg"
|
||||||
|
den_path = args.outdir / f"{args.prefix}denoised_residual.svg"
|
||||||
|
|
||||||
|
save_curve_svg(t, noisy, noisy_path)
|
||||||
|
save_curve_svg(t, denoised, den_path)
|
||||||
|
|
||||||
|
print(f"Wrote:\n {noisy_path}\n {den_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
188
arxiv-style/fig-scripts/make_ddpm_like_svg.py
Normal file
188
arxiv-style/fig-scripts/make_ddpm_like_svg.py
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
DDPM-like residual curve SVGs (separate files, fixed colors):
|
||||||
|
- noisy_residual.svg (blue)
|
||||||
|
- denoised_residual.svg (purple)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DDPMStyleParams:
|
||||||
|
seconds: float = 12.0
|
||||||
|
fs: int = 250
|
||||||
|
seed: int = 7
|
||||||
|
|
||||||
|
baseline_amp: float = 0.10
|
||||||
|
mid_wiggle_amp: float = 0.18
|
||||||
|
colored_noise_amp: float = 0.65
|
||||||
|
colored_alpha: float = 1.0
|
||||||
|
|
||||||
|
burst_rate_hz: float = 0.30
|
||||||
|
burst_amp: float = 0.9
|
||||||
|
burst_width_ms: float = 55
|
||||||
|
|
||||||
|
denoise_sigmas_ms: tuple[float, ...] = (25, 60, 140)
|
||||||
|
denoise_weights: tuple[float, ...] = (0.25, 0.35, 0.40)
|
||||||
|
denoise_texture_keep: float = 0.10
|
||||||
|
|
||||||
|
|
||||||
|
def gaussian_smooth(x: np.ndarray, sigma_samples: float) -> np.ndarray:
|
||||||
|
if sigma_samples <= 0:
|
||||||
|
return x.copy()
|
||||||
|
radius = int(np.ceil(4 * sigma_samples))
|
||||||
|
k = np.arange(-radius, radius + 1, dtype=float)
|
||||||
|
kernel = np.exp(-(k**2) / (2 * sigma_samples**2))
|
||||||
|
kernel /= kernel.sum()
|
||||||
|
return np.convolve(x, kernel, mode="same")
|
||||||
|
|
||||||
|
|
||||||
|
def colored_noise_1_f(n: int, rng: np.random.Generator, alpha: float) -> np.ndarray:
|
||||||
|
white = rng.normal(0, 1, size=n)
|
||||||
|
spec = np.fft.rfft(white)
|
||||||
|
|
||||||
|
freqs = np.fft.rfftfreq(n, d=1.0)
|
||||||
|
scale = np.ones_like(freqs)
|
||||||
|
nonzero = freqs > 0
|
||||||
|
scale[nonzero] = 1.0 / (freqs[nonzero] ** (alpha / 2.0))
|
||||||
|
|
||||||
|
spec *= scale
|
||||||
|
x = np.fft.irfft(spec, n=n)
|
||||||
|
|
||||||
|
x = x - np.mean(x)
|
||||||
|
x = x / (np.std(x) + 1e-9)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def make_ddpm_like_residual(p: DDPMStyleParams) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
|
rng = np.random.default_rng(p.seed)
|
||||||
|
n = int(p.seconds * p.fs)
|
||||||
|
t = np.linspace(0, p.seconds, n, endpoint=False)
|
||||||
|
|
||||||
|
baseline = (
|
||||||
|
0.8 * np.sin(2 * np.pi * 0.18 * t + 0.4)
|
||||||
|
+ 0.35 * np.sin(2 * np.pi * 0.06 * t + 2.2)
|
||||||
|
) * p.baseline_amp
|
||||||
|
|
||||||
|
mid = (
|
||||||
|
0.9 * np.sin(2 * np.pi * 0.9 * t + 1.1)
|
||||||
|
+ 0.5 * np.sin(2 * np.pi * 1.6 * t + 0.2)
|
||||||
|
+ 0.3 * np.sin(2 * np.pi * 2.4 * t + 2.6)
|
||||||
|
) * p.mid_wiggle_amp
|
||||||
|
|
||||||
|
col = colored_noise_1_f(n, rng, alpha=p.colored_alpha) * p.colored_noise_amp
|
||||||
|
|
||||||
|
expected = p.burst_rate_hz * p.seconds
|
||||||
|
k = rng.poisson(expected)
|
||||||
|
impulses = np.zeros(n)
|
||||||
|
if k > 0:
|
||||||
|
idx = rng.integers(0, n, size=k)
|
||||||
|
impulses[idx] = rng.normal(loc=1.0, scale=0.35, size=k)
|
||||||
|
|
||||||
|
width = max(int(p.fs * (p.burst_width_ms / 1000.0)), 7)
|
||||||
|
u = np.arange(width)
|
||||||
|
kernel = np.exp(-u / (p.fs * 0.012)) * np.hanning(width)
|
||||||
|
kernel /= (kernel.max() + 1e-9)
|
||||||
|
bursts = np.convolve(impulses, kernel, mode="same") * p.burst_amp
|
||||||
|
|
||||||
|
noisy = baseline + mid + col + bursts
|
||||||
|
|
||||||
|
sigmas_samples = [(ms / 1000.0) * p.fs / 3.0 for ms in p.denoise_sigmas_ms]
|
||||||
|
smooths = [gaussian_smooth(noisy, s) for s in sigmas_samples]
|
||||||
|
|
||||||
|
den_base = np.zeros_like(noisy)
|
||||||
|
for w, sm in zip(p.denoise_weights, smooths):
|
||||||
|
den_base += w * sm
|
||||||
|
|
||||||
|
hf = noisy - gaussian_smooth(noisy, sigma_samples=p.fs * 0.03)
|
||||||
|
denoised = den_base + p.denoise_texture_keep * (hf / (np.std(hf) + 1e-9)) * (0.10 * np.std(den_base))
|
||||||
|
|
||||||
|
return t, noisy, denoised
|
||||||
|
|
||||||
|
|
||||||
|
def save_single_curve_svg(
|
||||||
|
t: np.ndarray,
|
||||||
|
y: np.ndarray,
|
||||||
|
out_path: Path,
|
||||||
|
*,
|
||||||
|
color: str,
|
||||||
|
lw: float = 2.2,
|
||||||
|
) -> None:
|
||||||
|
fig = plt.figure(figsize=(5.4, 1.6), dpi=200)
|
||||||
|
|
||||||
|
# Make figure background transparent
|
||||||
|
fig.patch.set_alpha(0.0)
|
||||||
|
|
||||||
|
ax = fig.add_axes([0.03, 0.03, 0.94, 0.94])
|
||||||
|
|
||||||
|
# Make axes background transparent
|
||||||
|
ax.patch.set_alpha(0.0)
|
||||||
|
|
||||||
|
ax.plot(t, y, linewidth=lw, color=color)
|
||||||
|
|
||||||
|
# clean, diagram-friendly
|
||||||
|
ax.set_axis_off()
|
||||||
|
ymin, ymax = np.min(y), np.max(y)
|
||||||
|
ypad = 0.08 * (ymax - ymin + 1e-9)
|
||||||
|
ax.set_xlim(t[0], t[-1])
|
||||||
|
ax.set_ylim(ymin - ypad, ymax + ypad)
|
||||||
|
|
||||||
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fig.savefig(
|
||||||
|
out_path,
|
||||||
|
format="svg",
|
||||||
|
bbox_inches="tight",
|
||||||
|
pad_inches=0.0,
|
||||||
|
transparent=True, # <-- key for transparent output
|
||||||
|
)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
ap = argparse.ArgumentParser()
|
||||||
|
ap.add_argument("--outdir", type=Path, default=Path("."))
|
||||||
|
ap.add_argument("--seed", type=int, default=7)
|
||||||
|
ap.add_argument("--seconds", type=float, default=12.0)
|
||||||
|
ap.add_argument("--fs", type=int, default=250)
|
||||||
|
|
||||||
|
ap.add_argument("--alpha", type=float, default=1.0)
|
||||||
|
ap.add_argument("--noise-amp", type=float, default=0.65)
|
||||||
|
ap.add_argument("--texture-keep", type=float, default=0.10)
|
||||||
|
|
||||||
|
ap.add_argument("--prefix", type=str, default="")
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
p = DDPMStyleParams(
|
||||||
|
seconds=args.seconds,
|
||||||
|
fs=args.fs,
|
||||||
|
seed=args.seed,
|
||||||
|
colored_alpha=args.alpha,
|
||||||
|
colored_noise_amp=args.noise_amp,
|
||||||
|
denoise_texture_keep=args.texture_keep,
|
||||||
|
)
|
||||||
|
|
||||||
|
t, noisy, den = make_ddpm_like_residual(p)
|
||||||
|
|
||||||
|
outdir = args.outdir
|
||||||
|
noisy_path = outdir / f"{args.prefix}noisy_residual.svg"
|
||||||
|
den_path = outdir / f"{args.prefix}denoised_residual.svg"
|
||||||
|
|
||||||
|
# Fixed colors as you requested
|
||||||
|
save_single_curve_svg(t, noisy, noisy_path, color="blue")
|
||||||
|
save_single_curve_svg(t, den, den_path, color="purple")
|
||||||
|
|
||||||
|
print("Wrote:")
|
||||||
|
print(f" {noisy_path}")
|
||||||
|
print(f" {den_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
10
arxiv-style/fig-scripts/pyproject.toml
Normal file
10
arxiv-style/fig-scripts/pyproject.toml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
[project]
|
||||||
|
name = "fig-gen-ddpm"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = [
|
||||||
|
"numpy>=1.26",
|
||||||
|
"matplotlib>=3.8",
|
||||||
|
]
|
||||||
240
arxiv-style/fig-scripts/synth_ics_3d_waterfall.py
Normal file
240
arxiv-style/fig-scripts/synth_ics_3d_waterfall.py
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
3D "final combined outcome" (time × channel × value) with:
|
||||||
|
- NO numbers on axes (tick labels removed)
|
||||||
|
- Axis *titles* kept (texts are okay)
|
||||||
|
- Reduced whitespace: tight bbox + minimal margins
|
||||||
|
- White background (non-transparent) suitable for embedding into another SVG
|
||||||
|
|
||||||
|
Output:
|
||||||
|
default PNG, optional SVG (2D projected vectors)
|
||||||
|
|
||||||
|
Run:
|
||||||
|
uv run python synth_ics_3d_waterfall_tight.py --out ./assets/synth_ics_3d.png
|
||||||
|
uv run python synth_ics_3d_waterfall_tight.py --out ./assets/synth_ics_3d.svg --format svg
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Params:
|
||||||
|
seed: int = 7
|
||||||
|
seconds: float = 10.0
|
||||||
|
fs: int = 220
|
||||||
|
|
||||||
|
n_cont: int = 5
|
||||||
|
n_disc: int = 2
|
||||||
|
disc_vocab: int = 8
|
||||||
|
disc_change_rate_hz: float = 1.1
|
||||||
|
|
||||||
|
# view
|
||||||
|
elev: float = 25.0
|
||||||
|
azim: float = -58.0
|
||||||
|
|
||||||
|
# figure size (smaller, more "cube-like")
|
||||||
|
fig_w: float = 5.4
|
||||||
|
fig_h: float = 5.0
|
||||||
|
|
||||||
|
# discrete rendering
|
||||||
|
disc_z_scale: float = 0.45
|
||||||
|
disc_z_offset: float = -1.4
|
||||||
|
|
||||||
|
# margins (figure fraction)
|
||||||
|
left: float = 0.03
|
||||||
|
right: float = 0.99
|
||||||
|
bottom: float = 0.03
|
||||||
|
top: float = 0.99
|
||||||
|
|
||||||
|
|
||||||
|
def _smooth(x: np.ndarray, win: int) -> np.ndarray:
|
||||||
|
win = max(3, int(win) | 1)
|
||||||
|
k = np.ones(win, dtype=float)
|
||||||
|
k /= k.sum()
|
||||||
|
return np.convolve(x, k, mode="same")
|
||||||
|
|
||||||
|
|
||||||
|
def make_continuous(p: Params) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
rng = np.random.default_rng(p.seed)
|
||||||
|
T = int(p.seconds * p.fs)
|
||||||
|
t = np.linspace(0, p.seconds, T, endpoint=False)
|
||||||
|
|
||||||
|
Y = []
|
||||||
|
base_freqs = [0.08, 0.10, 0.12, 0.09, 0.11]
|
||||||
|
mid_freqs = [0.55, 0.70, 0.85, 0.62, 0.78]
|
||||||
|
|
||||||
|
for i in range(p.n_cont):
|
||||||
|
f1 = base_freqs[i % len(base_freqs)]
|
||||||
|
f2 = mid_freqs[i % len(mid_freqs)]
|
||||||
|
ph = rng.uniform(0, 2 * np.pi)
|
||||||
|
|
||||||
|
y = (
|
||||||
|
0.95 * np.sin(2 * np.pi * f1 * t + ph)
|
||||||
|
+ 0.28 * np.sin(2 * np.pi * f2 * t + 0.65 * ph)
|
||||||
|
)
|
||||||
|
|
||||||
|
bumps = np.zeros_like(t)
|
||||||
|
for _ in range(rng.integers(2, 4)):
|
||||||
|
mu = rng.uniform(0.8, p.seconds - 0.8)
|
||||||
|
sig = rng.uniform(0.25, 0.80)
|
||||||
|
bumps += np.exp(-0.5 * ((t - mu) / (sig + 1e-9)) ** 2)
|
||||||
|
y += 0.55 * bumps
|
||||||
|
|
||||||
|
noise = _smooth(rng.normal(0, 1, size=T), win=int(p.fs * 0.05))
|
||||||
|
y += 0.10 * noise
|
||||||
|
|
||||||
|
y = (y - y.mean()) / (y.std() + 1e-9)
|
||||||
|
Y.append(y)
|
||||||
|
|
||||||
|
return t, np.vstack(Y) # (n_cont, T)
|
||||||
|
|
||||||
|
|
||||||
|
def make_discrete(p: Params, t: np.ndarray) -> np.ndarray:
|
||||||
|
rng = np.random.default_rng(p.seed + 123)
|
||||||
|
T = len(t)
|
||||||
|
|
||||||
|
expected_changes = max(1, int(p.seconds * p.disc_change_rate_hz))
|
||||||
|
X = np.zeros((p.n_disc, T), dtype=int)
|
||||||
|
|
||||||
|
for c in range(p.n_disc):
|
||||||
|
k = rng.poisson(expected_changes) + 1
|
||||||
|
pts = np.unique(rng.integers(0, T, size=k))
|
||||||
|
pts = np.sort(np.concatenate([[0], pts, [T]]))
|
||||||
|
|
||||||
|
cur = rng.integers(0, p.disc_vocab)
|
||||||
|
for a, b in zip(pts[:-1], pts[1:]):
|
||||||
|
if a != 0 and rng.random() < 0.85:
|
||||||
|
cur = rng.integers(0, p.disc_vocab)
|
||||||
|
X[c, a:b] = cur
|
||||||
|
|
||||||
|
return X
|
||||||
|
|
||||||
|
|
||||||
|
def style_3d_axes(ax):
|
||||||
|
# Make panes white but less visually heavy
|
||||||
|
try:
|
||||||
|
# Keep pane fill ON (white background) but reduce edge prominence
|
||||||
|
ax.xaxis.pane.set_edgecolor("0.7")
|
||||||
|
ax.yaxis.pane.set_edgecolor("0.7")
|
||||||
|
ax.zaxis.pane.set_edgecolor("0.7")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
ax.grid(True, linewidth=0.4, alpha=0.30)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_tick_numbers_keep_axis_titles(ax):
|
||||||
|
# Remove tick labels (numbers) and tick marks, keep axis titles
|
||||||
|
ax.set_xticklabels([])
|
||||||
|
ax.set_yticklabels([])
|
||||||
|
ax.set_zticklabels([])
|
||||||
|
|
||||||
|
ax.tick_params(
|
||||||
|
axis="both",
|
||||||
|
which="both",
|
||||||
|
length=0, # no tick marks
|
||||||
|
pad=0,
|
||||||
|
)
|
||||||
|
# 3D has separate tick_params for z on some versions; this still works broadly:
|
||||||
|
try:
|
||||||
|
ax.zaxis.set_tick_params(length=0, pad=0)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
ap = argparse.ArgumentParser()
|
||||||
|
ap.add_argument("--out", type=Path, default=Path("synth_ics_3d.png"))
|
||||||
|
ap.add_argument("--format", choices=["png", "svg"], default="png")
|
||||||
|
|
||||||
|
ap.add_argument("--seed", type=int, default=7)
|
||||||
|
ap.add_argument("--seconds", type=float, default=10.0)
|
||||||
|
ap.add_argument("--fs", type=int, default=220)
|
||||||
|
|
||||||
|
ap.add_argument("--n-cont", type=int, default=5)
|
||||||
|
ap.add_argument("--n-disc", type=int, default=2)
|
||||||
|
ap.add_argument("--disc-vocab", type=int, default=8)
|
||||||
|
ap.add_argument("--disc-rate", type=float, default=1.1)
|
||||||
|
|
||||||
|
ap.add_argument("--elev", type=float, default=25.0)
|
||||||
|
ap.add_argument("--azim", type=float, default=-58.0)
|
||||||
|
|
||||||
|
ap.add_argument("--fig-w", type=float, default=5.4)
|
||||||
|
ap.add_argument("--fig-h", type=float, default=5.0)
|
||||||
|
|
||||||
|
ap.add_argument("--disc-z-scale", type=float, default=0.45)
|
||||||
|
ap.add_argument("--disc-z-offset", type=float, default=-1.4)
|
||||||
|
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
p = Params(
|
||||||
|
seed=args.seed,
|
||||||
|
seconds=args.seconds,
|
||||||
|
fs=args.fs,
|
||||||
|
n_cont=args.n_cont,
|
||||||
|
n_disc=args.n_disc,
|
||||||
|
disc_vocab=args.disc_vocab,
|
||||||
|
disc_change_rate_hz=args.disc_rate,
|
||||||
|
elev=args.elev,
|
||||||
|
azim=args.azim,
|
||||||
|
fig_w=args.fig_w,
|
||||||
|
fig_h=args.fig_h,
|
||||||
|
disc_z_scale=args.disc_z_scale,
|
||||||
|
disc_z_offset=args.disc_z_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
t, Yc = make_continuous(p)
|
||||||
|
Xd = make_discrete(p, t)
|
||||||
|
|
||||||
|
fig = plt.figure(figsize=(p.fig_w, p.fig_h), dpi=220, facecolor="white")
|
||||||
|
ax = fig.add_subplot(111, projection="3d")
|
||||||
|
style_3d_axes(ax)
|
||||||
|
|
||||||
|
# Reduce whitespace around axes (tight placement)
|
||||||
|
fig.subplots_adjust(left=p.left, right=p.right, bottom=p.bottom, top=p.top)
|
||||||
|
|
||||||
|
# Draw continuous channels
|
||||||
|
for i in range(p.n_cont):
|
||||||
|
y = np.full_like(t, fill_value=i, dtype=float)
|
||||||
|
z = Yc[i]
|
||||||
|
ax.plot(t, y, z, linewidth=2.0)
|
||||||
|
|
||||||
|
# Draw discrete channels as steps
|
||||||
|
for j in range(p.n_disc):
|
||||||
|
ch = p.n_cont + j
|
||||||
|
y = np.full_like(t, fill_value=ch, dtype=float)
|
||||||
|
z = p.disc_z_offset + p.disc_z_scale * Xd[j].astype(float)
|
||||||
|
ax.step(t, y, z, where="post", linewidth=2.2)
|
||||||
|
|
||||||
|
# Axis titles kept
|
||||||
|
ax.set_xlabel("time")
|
||||||
|
ax.set_ylabel("channel")
|
||||||
|
ax.set_zlabel("value")
|
||||||
|
|
||||||
|
# Remove numeric tick labels + tick marks
|
||||||
|
remove_tick_numbers_keep_axis_titles(ax)
|
||||||
|
|
||||||
|
# Camera
|
||||||
|
ax.view_init(elev=p.elev, azim=p.azim)
|
||||||
|
|
||||||
|
# Save tightly (minimize white border)
|
||||||
|
args.out.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
save_kwargs = dict(bbox_inches="tight", pad_inches=0.03, facecolor="white")
|
||||||
|
if args.format == "svg" or args.out.suffix.lower() == ".svg":
|
||||||
|
fig.savefig(args.out, format="svg", **save_kwargs)
|
||||||
|
else:
|
||||||
|
fig.savefig(args.out, format="png", **save_kwargs)
|
||||||
|
|
||||||
|
plt.close(fig)
|
||||||
|
print(f"Wrote: {args.out}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
262
arxiv-style/fig-scripts/transformer_math_figure.py
Normal file
262
arxiv-style/fig-scripts/transformer_math_figure.py
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Transformer-ish "trend" visuals with NO equations:
|
||||||
|
- attention_weights.svg : heatmap-like attention map (looks like "Transformer attends to positions")
|
||||||
|
- token_activation_trends.svg: multiple token-channel curves (continuous trends)
|
||||||
|
- discrete_tokens.svg : step-like discrete channel trends (optional)
|
||||||
|
|
||||||
|
All SVGs have transparent background and no axes (diagram-friendly).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Synthetic data generators
|
||||||
|
# ----------------------------
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Params:
|
||||||
|
seed: int = 7
|
||||||
|
T: int = 24 # sequence length (positions)
|
||||||
|
n_heads: int = 4 # attention heads to blend/choose
|
||||||
|
n_curves: int = 7 # curves in token_activation_trends
|
||||||
|
seconds: float = 10.0
|
||||||
|
fs: int = 200
|
||||||
|
|
||||||
|
|
||||||
|
def _gaussian(x: np.ndarray, mu: float, sig: float) -> np.ndarray:
|
||||||
|
return np.exp(-0.5 * ((x - mu) / (sig + 1e-9)) ** 2)
|
||||||
|
|
||||||
|
|
||||||
|
def make_attention_map(T: int, rng: np.random.Generator, mode: str) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Create a transformer-like attention weight matrix A (T x T) with different visual styles:
|
||||||
|
- "local": mostly near-diagonal attention
|
||||||
|
- "global": some global tokens attend broadly
|
||||||
|
- "causal": lower-triangular (decoder-like) with local preference
|
||||||
|
"""
|
||||||
|
i = np.arange(T)[:, None] # query positions
|
||||||
|
j = np.arange(T)[None, :] # key positions
|
||||||
|
|
||||||
|
if mode == "local":
|
||||||
|
logits = -((i - j) ** 2) / (2 * (2.2 ** 2))
|
||||||
|
logits += 0.15 * rng.normal(size=(T, T))
|
||||||
|
elif mode == "global":
|
||||||
|
logits = -((i - j) ** 2) / (2 * (3.0 ** 2))
|
||||||
|
# Add a few "global" key positions that many queries attend to
|
||||||
|
globals_ = rng.choice(T, size=max(2, T // 10), replace=False)
|
||||||
|
for g in globals_:
|
||||||
|
logits += 1.2 * _gaussian(j, mu=g, sig=1.0)
|
||||||
|
logits += 0.12 * rng.normal(size=(T, T))
|
||||||
|
elif mode == "causal":
|
||||||
|
logits = -((i - j) ** 2) / (2 * (2.0 ** 2))
|
||||||
|
logits += 0.12 * rng.normal(size=(T, T))
|
||||||
|
logits = np.where(j <= i, logits, -1e9) # mask future
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown attention mode: {mode}")
|
||||||
|
|
||||||
|
# softmax rows
|
||||||
|
logits = logits - np.max(logits, axis=1, keepdims=True)
|
||||||
|
A = np.exp(logits)
|
||||||
|
A /= (np.sum(A, axis=1, keepdims=True) + 1e-9)
|
||||||
|
return A
|
||||||
|
|
||||||
|
|
||||||
|
def make_token_activation_trends(p: Params) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Multiple smooth curves that feel like "representations evolving across layers/time".
|
||||||
|
Returns:
|
||||||
|
t: (N,)
|
||||||
|
Y: (n_curves, N)
|
||||||
|
"""
|
||||||
|
rng = np.random.default_rng(p.seed)
|
||||||
|
N = int(p.seconds * p.fs)
|
||||||
|
t = np.linspace(0, p.seconds, N, endpoint=False)
|
||||||
|
|
||||||
|
Y = []
|
||||||
|
for k in range(p.n_curves):
|
||||||
|
# Multi-scale smooth components + some bursty response
|
||||||
|
f1 = 0.10 + 0.04 * k
|
||||||
|
f2 = 0.60 + 0.18 * (k % 3)
|
||||||
|
phase = rng.uniform(0, 2 * np.pi)
|
||||||
|
|
||||||
|
base = 0.9 * np.sin(2 * np.pi * f1 * t + phase) + 0.35 * np.sin(2 * np.pi * f2 * t + 0.7 * phase)
|
||||||
|
|
||||||
|
# "attention-like gating": a few bumps where the curve spikes smoothly
|
||||||
|
bumps = np.zeros_like(t)
|
||||||
|
for _ in range(rng.integers(2, 5)):
|
||||||
|
mu = rng.uniform(0.5, p.seconds - 0.5)
|
||||||
|
sig = rng.uniform(0.15, 0.55)
|
||||||
|
bumps += 0.9 * _gaussian(t, mu=mu, sig=sig)
|
||||||
|
|
||||||
|
noise = rng.normal(0, 1, size=N)
|
||||||
|
noise = np.convolve(noise, np.ones(11) / 11.0, mode="same") # smooth noise
|
||||||
|
|
||||||
|
y = base + 0.85 * bumps + 0.12 * noise
|
||||||
|
|
||||||
|
# normalize and vertically offset
|
||||||
|
y = (y - y.mean()) / (y.std() + 1e-9)
|
||||||
|
y = 0.75 * y + 0.18 * k
|
||||||
|
Y.append(y)
|
||||||
|
|
||||||
|
return t, np.vstack(Y)
|
||||||
|
|
||||||
|
|
||||||
|
def make_discrete_trends(p: Params, vocab: int = 9, change_rate_hz: float = 1.3) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Discrete step-like channels: useful if you want a "token-id / discrete feature" feel.
|
||||||
|
Returns:
|
||||||
|
t: (N,)
|
||||||
|
X: (n_curves, N) integers
|
||||||
|
"""
|
||||||
|
rng = np.random.default_rng(p.seed + 123)
|
||||||
|
N = int(p.seconds * p.fs)
|
||||||
|
t = np.linspace(0, p.seconds, N, endpoint=False)
|
||||||
|
|
||||||
|
expected = max(1, int(p.seconds * change_rate_hz))
|
||||||
|
X = np.zeros((p.n_curves, N), dtype=int)
|
||||||
|
for c in range(p.n_curves):
|
||||||
|
k = rng.poisson(expected) + 1
|
||||||
|
pts = np.unique(rng.integers(0, N, size=k))
|
||||||
|
pts = np.sort(np.concatenate([[0], pts, [N]]))
|
||||||
|
|
||||||
|
cur = rng.integers(0, vocab)
|
||||||
|
for a, b in zip(pts[:-1], pts[1:]):
|
||||||
|
if a != 0 and rng.random() < 0.9:
|
||||||
|
cur = rng.integers(0, vocab)
|
||||||
|
X[c, a:b] = cur
|
||||||
|
|
||||||
|
return t, X
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Plot helpers (SVG, transparent, axes-free)
|
||||||
|
# ----------------------------
|
||||||
|
|
||||||
|
def _transparent_fig_ax(width_in: float, height_in: float):
|
||||||
|
fig = plt.figure(figsize=(width_in, height_in), dpi=200)
|
||||||
|
fig.patch.set_alpha(0.0)
|
||||||
|
ax = fig.add_axes([0.03, 0.03, 0.94, 0.94])
|
||||||
|
ax.patch.set_alpha(0.0)
|
||||||
|
ax.set_axis_off()
|
||||||
|
return fig, ax
|
||||||
|
|
||||||
|
|
||||||
|
def save_attention_svg(A: np.ndarray, out: Path, *, show_colorbar: bool = False) -> None:
|
||||||
|
fig, ax = _transparent_fig_ax(4.2, 4.2)
|
||||||
|
|
||||||
|
# Using default colormap (no explicit color specification)
|
||||||
|
im = ax.imshow(A, aspect="equal", interpolation="nearest")
|
||||||
|
|
||||||
|
if show_colorbar:
|
||||||
|
# colorbar can be useful, but it adds clutter in diagrams
|
||||||
|
cax = fig.add_axes([0.92, 0.10, 0.03, 0.80])
|
||||||
|
cb = fig.colorbar(im, cax=cax)
|
||||||
|
cb.outline.set_linewidth(1.0)
|
||||||
|
|
||||||
|
out.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fig.savefig(out, format="svg", bbox_inches="tight", pad_inches=0.0, transparent=True)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def save_multi_curve_svg(t: np.ndarray, Y: np.ndarray, out: Path, *, lw: float = 2.0) -> None:
|
||||||
|
fig, ax = _transparent_fig_ax(6.0, 2.2)
|
||||||
|
|
||||||
|
for i in range(Y.shape[0]):
|
||||||
|
ax.plot(t, Y[i], linewidth=lw)
|
||||||
|
|
||||||
|
y_all = Y.reshape(-1)
|
||||||
|
ymin, ymax = float(np.min(y_all)), float(np.max(y_all))
|
||||||
|
ypad = 0.08 * (ymax - ymin + 1e-9)
|
||||||
|
ax.set_xlim(t[0], t[-1])
|
||||||
|
ax.set_ylim(ymin - ypad, ymax + ypad)
|
||||||
|
|
||||||
|
out.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fig.savefig(out, format="svg", bbox_inches="tight", pad_inches=0.0, transparent=True)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def save_discrete_svg(t: np.ndarray, X: np.ndarray, out: Path, *, lw: float = 2.0, spacing: float = 1.25) -> None:
|
||||||
|
fig, ax = _transparent_fig_ax(6.0, 2.2)
|
||||||
|
|
||||||
|
for i in range(X.shape[0]):
|
||||||
|
y = X[i].astype(float) + i * spacing
|
||||||
|
ax.step(t, y, where="post", linewidth=lw)
|
||||||
|
|
||||||
|
y_all = (X.astype(float) + np.arange(X.shape[0])[:, None] * spacing).reshape(-1)
|
||||||
|
ymin, ymax = float(np.min(y_all)), float(np.max(y_all))
|
||||||
|
ypad = 0.10 * (ymax - ymin + 1e-9)
|
||||||
|
ax.set_xlim(t[0], t[-1])
|
||||||
|
ax.set_ylim(ymin - ypad, ymax + ypad)
|
||||||
|
|
||||||
|
out.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fig.savefig(out, format="svg", bbox_inches="tight", pad_inches=0.0, transparent=True)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# CLI
|
||||||
|
# ----------------------------
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
ap = argparse.ArgumentParser()
|
||||||
|
ap.add_argument("--outdir", type=Path, default=Path("out"))
|
||||||
|
ap.add_argument("--seed", type=int, default=7)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
ap.add_argument("--T", type=int, default=24)
|
||||||
|
ap.add_argument("--attn-mode", type=str, default="local", choices=["local", "global", "causal"])
|
||||||
|
ap.add_argument("--colorbar", action="store_true")
|
||||||
|
|
||||||
|
# curves
|
||||||
|
ap.add_argument("--seconds", type=float, default=10.0)
|
||||||
|
ap.add_argument("--fs", type=int, default=200)
|
||||||
|
ap.add_argument("--n-curves", type=int, default=7)
|
||||||
|
|
||||||
|
# discrete optional
|
||||||
|
ap.add_argument("--with-discrete", action="store_true")
|
||||||
|
ap.add_argument("--disc-vocab", type=int, default=9)
|
||||||
|
ap.add_argument("--disc-rate", type=float, default=1.3)
|
||||||
|
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
p = Params(
|
||||||
|
seed=args.seed,
|
||||||
|
T=args.T,
|
||||||
|
n_curves=args.n_curves,
|
||||||
|
seconds=args.seconds,
|
||||||
|
fs=args.fs,
|
||||||
|
)
|
||||||
|
|
||||||
|
rng = np.random.default_rng(args.seed)
|
||||||
|
|
||||||
|
# 1) attention map
|
||||||
|
A = make_attention_map(args.T, rng, mode=args.attn_mode)
|
||||||
|
save_attention_svg(A, args.outdir / "attention_weights.svg", show_colorbar=args.colorbar)
|
||||||
|
|
||||||
|
# 2) continuous trends
|
||||||
|
t, Y = make_token_activation_trends(p)
|
||||||
|
save_multi_curve_svg(t, Y, args.outdir / "token_activation_trends.svg")
|
||||||
|
|
||||||
|
# 3) discrete trends (optional)
|
||||||
|
if args.with_discrete:
|
||||||
|
td, X = make_discrete_trends(p, vocab=args.disc_vocab, change_rate_hz=args.disc_rate)
|
||||||
|
save_discrete_svg(td, X, args.outdir / "discrete_tokens.svg")
|
||||||
|
|
||||||
|
print("Wrote:")
|
||||||
|
print(f" {args.outdir / 'attention_weights.svg'}")
|
||||||
|
print(f" {args.outdir / 'token_activation_trends.svg'}")
|
||||||
|
if args.with_discrete:
|
||||||
|
print(f" {args.outdir / 'discrete_tokens.svg'}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user