forked from manbo/internal-docs
Add: python scripts for figure generation
This commit is contained in:
262
arxiv-style/fig-scripts/transformer_math_figure.py
Normal file
262
arxiv-style/fig-scripts/transformer_math_figure.py
Normal file
@@ -0,0 +1,262 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Transformer-ish "trend" visuals with NO equations:
|
||||
- attention_weights.svg : heatmap-like attention map (looks like "Transformer attends to positions")
|
||||
- token_activation_trends.svg: multiple token-channel curves (continuous trends)
|
||||
- discrete_tokens.svg : step-like discrete channel trends (optional)
|
||||
|
||||
All SVGs have transparent background and no axes (diagram-friendly).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Synthetic data generators
|
||||
# ----------------------------
|
||||
|
||||
@dataclass
|
||||
class Params:
|
||||
seed: int = 7
|
||||
T: int = 24 # sequence length (positions)
|
||||
n_heads: int = 4 # attention heads to blend/choose
|
||||
n_curves: int = 7 # curves in token_activation_trends
|
||||
seconds: float = 10.0
|
||||
fs: int = 200
|
||||
|
||||
|
||||
def _gaussian(x: np.ndarray, mu: float, sig: float) -> np.ndarray:
|
||||
return np.exp(-0.5 * ((x - mu) / (sig + 1e-9)) ** 2)
|
||||
|
||||
|
||||
def make_attention_map(T: int, rng: np.random.Generator, mode: str) -> np.ndarray:
|
||||
"""
|
||||
Create a transformer-like attention weight matrix A (T x T) with different visual styles:
|
||||
- "local": mostly near-diagonal attention
|
||||
- "global": some global tokens attend broadly
|
||||
- "causal": lower-triangular (decoder-like) with local preference
|
||||
"""
|
||||
i = np.arange(T)[:, None] # query positions
|
||||
j = np.arange(T)[None, :] # key positions
|
||||
|
||||
if mode == "local":
|
||||
logits = -((i - j) ** 2) / (2 * (2.2 ** 2))
|
||||
logits += 0.15 * rng.normal(size=(T, T))
|
||||
elif mode == "global":
|
||||
logits = -((i - j) ** 2) / (2 * (3.0 ** 2))
|
||||
# Add a few "global" key positions that many queries attend to
|
||||
globals_ = rng.choice(T, size=max(2, T // 10), replace=False)
|
||||
for g in globals_:
|
||||
logits += 1.2 * _gaussian(j, mu=g, sig=1.0)
|
||||
logits += 0.12 * rng.normal(size=(T, T))
|
||||
elif mode == "causal":
|
||||
logits = -((i - j) ** 2) / (2 * (2.0 ** 2))
|
||||
logits += 0.12 * rng.normal(size=(T, T))
|
||||
logits = np.where(j <= i, logits, -1e9) # mask future
|
||||
else:
|
||||
raise ValueError(f"Unknown attention mode: {mode}")
|
||||
|
||||
# softmax rows
|
||||
logits = logits - np.max(logits, axis=1, keepdims=True)
|
||||
A = np.exp(logits)
|
||||
A /= (np.sum(A, axis=1, keepdims=True) + 1e-9)
|
||||
return A
|
||||
|
||||
|
||||
def make_token_activation_trends(p: Params) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Multiple smooth curves that feel like "representations evolving across layers/time".
|
||||
Returns:
|
||||
t: (N,)
|
||||
Y: (n_curves, N)
|
||||
"""
|
||||
rng = np.random.default_rng(p.seed)
|
||||
N = int(p.seconds * p.fs)
|
||||
t = np.linspace(0, p.seconds, N, endpoint=False)
|
||||
|
||||
Y = []
|
||||
for k in range(p.n_curves):
|
||||
# Multi-scale smooth components + some bursty response
|
||||
f1 = 0.10 + 0.04 * k
|
||||
f2 = 0.60 + 0.18 * (k % 3)
|
||||
phase = rng.uniform(0, 2 * np.pi)
|
||||
|
||||
base = 0.9 * np.sin(2 * np.pi * f1 * t + phase) + 0.35 * np.sin(2 * np.pi * f2 * t + 0.7 * phase)
|
||||
|
||||
# "attention-like gating": a few bumps where the curve spikes smoothly
|
||||
bumps = np.zeros_like(t)
|
||||
for _ in range(rng.integers(2, 5)):
|
||||
mu = rng.uniform(0.5, p.seconds - 0.5)
|
||||
sig = rng.uniform(0.15, 0.55)
|
||||
bumps += 0.9 * _gaussian(t, mu=mu, sig=sig)
|
||||
|
||||
noise = rng.normal(0, 1, size=N)
|
||||
noise = np.convolve(noise, np.ones(11) / 11.0, mode="same") # smooth noise
|
||||
|
||||
y = base + 0.85 * bumps + 0.12 * noise
|
||||
|
||||
# normalize and vertically offset
|
||||
y = (y - y.mean()) / (y.std() + 1e-9)
|
||||
y = 0.75 * y + 0.18 * k
|
||||
Y.append(y)
|
||||
|
||||
return t, np.vstack(Y)
|
||||
|
||||
|
||||
def make_discrete_trends(p: Params, vocab: int = 9, change_rate_hz: float = 1.3) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Discrete step-like channels: useful if you want a "token-id / discrete feature" feel.
|
||||
Returns:
|
||||
t: (N,)
|
||||
X: (n_curves, N) integers
|
||||
"""
|
||||
rng = np.random.default_rng(p.seed + 123)
|
||||
N = int(p.seconds * p.fs)
|
||||
t = np.linspace(0, p.seconds, N, endpoint=False)
|
||||
|
||||
expected = max(1, int(p.seconds * change_rate_hz))
|
||||
X = np.zeros((p.n_curves, N), dtype=int)
|
||||
for c in range(p.n_curves):
|
||||
k = rng.poisson(expected) + 1
|
||||
pts = np.unique(rng.integers(0, N, size=k))
|
||||
pts = np.sort(np.concatenate([[0], pts, [N]]))
|
||||
|
||||
cur = rng.integers(0, vocab)
|
||||
for a, b in zip(pts[:-1], pts[1:]):
|
||||
if a != 0 and rng.random() < 0.9:
|
||||
cur = rng.integers(0, vocab)
|
||||
X[c, a:b] = cur
|
||||
|
||||
return t, X
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Plot helpers (SVG, transparent, axes-free)
|
||||
# ----------------------------
|
||||
|
||||
def _transparent_fig_ax(width_in: float, height_in: float):
|
||||
fig = plt.figure(figsize=(width_in, height_in), dpi=200)
|
||||
fig.patch.set_alpha(0.0)
|
||||
ax = fig.add_axes([0.03, 0.03, 0.94, 0.94])
|
||||
ax.patch.set_alpha(0.0)
|
||||
ax.set_axis_off()
|
||||
return fig, ax
|
||||
|
||||
|
||||
def save_attention_svg(A: np.ndarray, out: Path, *, show_colorbar: bool = False) -> None:
|
||||
fig, ax = _transparent_fig_ax(4.2, 4.2)
|
||||
|
||||
# Using default colormap (no explicit color specification)
|
||||
im = ax.imshow(A, aspect="equal", interpolation="nearest")
|
||||
|
||||
if show_colorbar:
|
||||
# colorbar can be useful, but it adds clutter in diagrams
|
||||
cax = fig.add_axes([0.92, 0.10, 0.03, 0.80])
|
||||
cb = fig.colorbar(im, cax=cax)
|
||||
cb.outline.set_linewidth(1.0)
|
||||
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
fig.savefig(out, format="svg", bbox_inches="tight", pad_inches=0.0, transparent=True)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def save_multi_curve_svg(t: np.ndarray, Y: np.ndarray, out: Path, *, lw: float = 2.0) -> None:
|
||||
fig, ax = _transparent_fig_ax(6.0, 2.2)
|
||||
|
||||
for i in range(Y.shape[0]):
|
||||
ax.plot(t, Y[i], linewidth=lw)
|
||||
|
||||
y_all = Y.reshape(-1)
|
||||
ymin, ymax = float(np.min(y_all)), float(np.max(y_all))
|
||||
ypad = 0.08 * (ymax - ymin + 1e-9)
|
||||
ax.set_xlim(t[0], t[-1])
|
||||
ax.set_ylim(ymin - ypad, ymax + ypad)
|
||||
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
fig.savefig(out, format="svg", bbox_inches="tight", pad_inches=0.0, transparent=True)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def save_discrete_svg(t: np.ndarray, X: np.ndarray, out: Path, *, lw: float = 2.0, spacing: float = 1.25) -> None:
|
||||
fig, ax = _transparent_fig_ax(6.0, 2.2)
|
||||
|
||||
for i in range(X.shape[0]):
|
||||
y = X[i].astype(float) + i * spacing
|
||||
ax.step(t, y, where="post", linewidth=lw)
|
||||
|
||||
y_all = (X.astype(float) + np.arange(X.shape[0])[:, None] * spacing).reshape(-1)
|
||||
ymin, ymax = float(np.min(y_all)), float(np.max(y_all))
|
||||
ypad = 0.10 * (ymax - ymin + 1e-9)
|
||||
ax.set_xlim(t[0], t[-1])
|
||||
ax.set_ylim(ymin - ypad, ymax + ypad)
|
||||
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
fig.savefig(out, format="svg", bbox_inches="tight", pad_inches=0.0, transparent=True)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# CLI
|
||||
# ----------------------------
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--outdir", type=Path, default=Path("out"))
|
||||
ap.add_argument("--seed", type=int, default=7)
|
||||
|
||||
# attention
|
||||
ap.add_argument("--T", type=int, default=24)
|
||||
ap.add_argument("--attn-mode", type=str, default="local", choices=["local", "global", "causal"])
|
||||
ap.add_argument("--colorbar", action="store_true")
|
||||
|
||||
# curves
|
||||
ap.add_argument("--seconds", type=float, default=10.0)
|
||||
ap.add_argument("--fs", type=int, default=200)
|
||||
ap.add_argument("--n-curves", type=int, default=7)
|
||||
|
||||
# discrete optional
|
||||
ap.add_argument("--with-discrete", action="store_true")
|
||||
ap.add_argument("--disc-vocab", type=int, default=9)
|
||||
ap.add_argument("--disc-rate", type=float, default=1.3)
|
||||
|
||||
args = ap.parse_args()
|
||||
|
||||
p = Params(
|
||||
seed=args.seed,
|
||||
T=args.T,
|
||||
n_curves=args.n_curves,
|
||||
seconds=args.seconds,
|
||||
fs=args.fs,
|
||||
)
|
||||
|
||||
rng = np.random.default_rng(args.seed)
|
||||
|
||||
# 1) attention map
|
||||
A = make_attention_map(args.T, rng, mode=args.attn_mode)
|
||||
save_attention_svg(A, args.outdir / "attention_weights.svg", show_colorbar=args.colorbar)
|
||||
|
||||
# 2) continuous trends
|
||||
t, Y = make_token_activation_trends(p)
|
||||
save_multi_curve_svg(t, Y, args.outdir / "token_activation_trends.svg")
|
||||
|
||||
# 3) discrete trends (optional)
|
||||
if args.with_discrete:
|
||||
td, X = make_discrete_trends(p, vocab=args.disc_vocab, change_rate_hz=args.disc_rate)
|
||||
save_discrete_svg(td, X, args.outdir / "discrete_tokens.svg")
|
||||
|
||||
print("Wrote:")
|
||||
print(f" {args.outdir / 'attention_weights.svg'}")
|
||||
print(f" {args.outdir / 'token_activation_trends.svg'}")
|
||||
if args.with_discrete:
|
||||
print(f" {args.outdir / 'discrete_tokens.svg'}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user