#!/usr/bin/env python3
"""
Transformer-ish "trend" visuals with NO equations:

- attention_weights.svg      : heatmap-like attention map
                               (looks like "Transformer attends to positions")
- token_activation_trends.svg: multiple token-channel curves (continuous trends)
- discrete_tokens.svg        : step-like discrete channel trends (optional)

All SVGs have transparent background and no axes (diagram-friendly).
"""

from __future__ import annotations

import argparse
from dataclasses import dataclass
from pathlib import Path

import numpy as np

try:
    import matplotlib.pyplot as plt
except ImportError:
    # Only the SVG writers need matplotlib; keeping the import soft lets the
    # data generators below be imported with NumPy alone. Rendering without
    # matplotlib fails with a clear ImportError in _transparent_fig_ax().
    plt = None


# ----------------------------
# Synthetic data generators
# ----------------------------

@dataclass
class Params:
    """Settings shared by the synthetic data generators."""

    seed: int = 7  # base seed for the NumPy random generators
    T: int = 24  # sequence length (positions)
    n_heads: int = 4  # attention heads to blend/choose (not read by any generator in this file)
    n_curves: int = 7  # number of channels produced by the curve / step generators
    seconds: float = 10.0  # length of the time axis used by the curve / step generators
    fs: int = 200  # samples per unit of `seconds`


def _gaussian(x: np.ndarray, mu: float, sig: float) -> np.ndarray:
    """Return an unnormalised Gaussian bump over ``x`` (value 1 at ``x == mu``)."""
    # The tiny epsilon keeps sig == 0 from dividing by zero.
    return np.exp(-0.5 * ((x - mu) / (sig + 1e-9)) ** 2)


def _moving_average(x: np.ndarray, width: int) -> np.ndarray:
    """Box-filter ``x`` with a ``width``-tap kernel, always returning ``x.size`` samples.

    ``np.convolve(..., mode="same")`` returns ``max(len(x), width)`` samples, so for
    a signal shorter than the kernel the centred window is cut out of the full
    convolution by hand instead.
    """
    kernel = np.ones(width) / width
    if x.size >= width:
        return np.convolve(x, kernel, mode="same")
    full = np.convolve(x, kernel, mode="full")
    start = (width - 1) // 2
    return full[start:start + x.size]


def _num_samples(p: Params) -> int:
    """Return ``int(p.seconds * p.fs)``, rejecting settings that yield no samples."""
    n_samples = int(p.seconds * p.fs)
    if n_samples < 1:
        raise ValueError(
            f"seconds * fs must give at least one sample, got {p.seconds} * {p.fs}"
        )
    return n_samples


def make_attention_map(T: int, rng: np.random.Generator, mode: str) -> np.ndarray:
    """
    Create a transformer-like attention weight matrix A (T x T) with different
    visual styles:

    - "local":  mostly near-diagonal attention
    - "global": some global tokens attend broadly
    - "causal": lower-triangular (decoder-like) with local preference

    Args:
        T: number of positions; must be at least 1.
        rng: NumPy random generator supplying the noise (and, for "global",
            the choice of global key positions).
        mode: one of "local", "global", "causal".

    Returns:
        A (T, T) float array; each row is a softmax over key positions.

    Raises:
        ValueError: if ``T < 1`` or ``mode`` is not one of the styles above.
    """
    if T < 1:
        raise ValueError(f"T must be at least 1, got {T}")

    i = np.arange(T)[:, None]  # query positions
    j = np.arange(T)[None, :]  # key positions

    if mode == "local":
        logits = -((i - j) ** 2) / (2 * (2.2 ** 2))
        logits += 0.15 * rng.normal(size=(T, T))
    elif mode == "global":
        logits = -((i - j) ** 2) / (2 * (3.0 ** 2))
        # Add a few "global" key positions that many queries attend to.
        # Sampling without replacement cannot draw more than T positions, so
        # the count is clamped for very short sequences (T < 2).
        n_global = min(T, max(2, T // 10))
        globals_ = rng.choice(T, size=n_global, replace=False)
        for g in globals_:
            logits += 1.2 * _gaussian(j, mu=g, sig=1.0)
        logits += 0.12 * rng.normal(size=(T, T))
    elif mode == "causal":
        logits = -((i - j) ** 2) / (2 * (2.0 ** 2))
        logits += 0.12 * rng.normal(size=(T, T))
        logits = np.where(j <= i, logits, -1e9)  # mask future
    else:
        raise ValueError(f"Unknown attention mode: {mode}")

    # softmax rows (subtract the row max first for numerical stability)
    logits = logits - np.max(logits, axis=1, keepdims=True)
    A = np.exp(logits)
    A /= (np.sum(A, axis=1, keepdims=True) + 1e-9)
    return A


def make_token_activation_trends(p: Params) -> tuple[np.ndarray, np.ndarray]:
    """
    Multiple smooth curves that feel like "representations evolving across
    layers/time".

    Args:
        p: generator settings; reads ``seed``, ``seconds``, ``fs`` and ``n_curves``.

    Returns:
        t: (N,) sample times in ``[0, seconds)`` with ``N = int(seconds * fs)``
        Y: (n_curves, N) curves; curve ``k`` is standardised, scaled by 0.75
           and offset by ``0.18 * k``

    Raises:
        ValueError: if ``n_curves < 1`` or ``seconds * fs`` gives no samples.
    """
    if p.n_curves < 1:
        raise ValueError(f"n_curves must be at least 1, got {p.n_curves}")
    N = _num_samples(p)

    rng = np.random.default_rng(p.seed)
    t = np.linspace(0, p.seconds, N, endpoint=False)

    # Bump centres stay 0.5 away from both ends. For windows shorter than 1 the
    # margin shrinks to half the window so that low <= high still holds
    # (rng.uniform is undefined for high < low).
    margin = min(0.5, 0.5 * p.seconds)

    Y = []
    for k in range(p.n_curves):
        # Multi-scale smooth components + some bursty response
        f1 = 0.10 + 0.04 * k
        f2 = 0.60 + 0.18 * (k % 3)
        phase = rng.uniform(0, 2 * np.pi)
        base = (
            0.9 * np.sin(2 * np.pi * f1 * t + phase)
            + 0.35 * np.sin(2 * np.pi * f2 * t + 0.7 * phase)
        )

        # "attention-like gating": a few bumps where the curve spikes smoothly
        bumps = np.zeros_like(t)
        for _ in range(int(rng.integers(2, 5))):
            mu = rng.uniform(margin, p.seconds - margin)
            sig = rng.uniform(0.15, 0.55)
            bumps += 0.9 * _gaussian(t, mu=mu, sig=sig)

        noise = _moving_average(rng.normal(0, 1, size=N), 11)  # smooth noise

        y = base + 0.85 * bumps + 0.12 * noise
        # normalize and vertically offset
        y = (y - y.mean()) / (y.std() + 1e-9)
        y = 0.75 * y + 0.18 * k
        Y.append(y)

    return t, np.vstack(Y)


def make_discrete_trends(
    p: Params, vocab: int = 9, change_rate_hz: float = 1.3
) -> tuple[np.ndarray, np.ndarray]:
    """
    Discrete step-like channels: useful if you want a "token-id / discrete
    feature" feel.

    Args:
        p: generator settings; reads ``seed``, ``seconds``, ``fs`` and ``n_curves``.
        vocab: values are drawn from ``0 .. vocab - 1``.
        change_rate_hz: ``int(seconds * change_rate_hz)`` (at least 1) is the
            Poisson mean used when drawing the change points of each channel.

    Returns:
        t: (N,) sample times in ``[0, seconds)`` with ``N = int(seconds * fs)``
        X: (n_curves, N) integers

    Raises:
        ValueError: if ``seconds * fs`` gives no samples.
    """
    N = _num_samples(p)

    # Offset the seed so this stream differs from the continuous curves.
    rng = np.random.default_rng(p.seed + 123)
    t = np.linspace(0, p.seconds, N, endpoint=False)

    expected = max(1, int(p.seconds * change_rate_hz))
    X = np.zeros((p.n_curves, N), dtype=int)

    for c in range(p.n_curves):
        k = rng.poisson(expected) + 1
        pts = np.unique(rng.integers(0, N, size=k))
        # np.unique returns sorted values in [0, N - 1], so bracketing them
        # with 0 and N already yields ordered segment boundaries.
        pts = np.concatenate([[0], pts, [N]])
        cur = rng.integers(0, vocab)
        for a, b in zip(pts[:-1], pts[1:]):
            # The first segment keeps the initial value; later segments are
            # re-drawn with probability 0.9 (the new draw may repeat the old one).
            if a != 0 and rng.random() < 0.9:
                cur = rng.integers(0, vocab)
            X[c, a:b] = cur

    return t, X


# ----------------------------
# Plot helpers (SVG, transparent, axes-free)
# ----------------------------

def _transparent_fig_ax(width_in: float, height_in: float):
    """Create a figure and a single axes-free axes, both with transparent patches.

    Raises:
        ImportError: if matplotlib could not be imported.
    """
    if plt is None:
        raise ImportError("matplotlib is required to render SVG figures")
    fig = plt.figure(figsize=(width_in, height_in), dpi=200)
    fig.patch.set_alpha(0.0)
    ax = fig.add_axes([0.03, 0.03, 0.94, 0.94])
    ax.patch.set_alpha(0.0)
    ax.set_axis_off()
    return fig, ax


def _write_svg(fig, out: Path) -> None:
    """Write ``fig`` to ``out`` as a tight, transparent SVG, creating parent dirs."""
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out, format="svg", bbox_inches="tight", pad_inches=0.0, transparent=True)


def save_attention_svg(A: np.ndarray, out: Path, *, show_colorbar: bool = False) -> None:
    """Render matrix ``A`` as a heatmap and save it to ``out`` as SVG.

    Args:
        A: 2-D array passed to ``imshow``.
        out: destination file; missing parent directories are created.
        show_colorbar: if true, add a colorbar in a narrow axes on the right.
    """
    fig, ax = _transparent_fig_ax(4.2, 4.2)
    try:
        # Using default colormap (no explicit color specification)
        im = ax.imshow(A, aspect="equal", interpolation="nearest")

        if show_colorbar:
            # colorbar can be useful, but it adds clutter in diagrams
            cax = fig.add_axes([0.92, 0.10, 0.03, 0.80])
            cb = fig.colorbar(im, cax=cax)
            cb.outline.set_linewidth(1.0)

        _write_svg(fig, out)
    finally:
        # Close even when plotting/saving fails so the figure is not leaked.
        plt.close(fig)


def save_multi_curve_svg(t: np.ndarray, Y: np.ndarray, out: Path, *, lw: float = 2.0) -> None:
    """Plot every row of ``Y`` against ``t`` and save the figure to ``out`` as SVG.

    Args:
        t: (N,) x-coordinates.
        Y: (n_curves, N) curves, one line per row.
        out: destination file; missing parent directories are created.
        lw: line width.
    """
    fig, ax = _transparent_fig_ax(6.0, 2.2)
    try:
        for row in Y:
            ax.plot(t, row, linewidth=lw)

        ymin, ymax = float(np.min(Y)), float(np.max(Y))
        ypad = 0.08 * (ymax - ymin + 1e-9)  # epsilon keeps a flat signal from zero padding
        ax.set_xlim(t[0], t[-1])
        ax.set_ylim(ymin - ypad, ymax + ypad)

        _write_svg(fig, out)
    finally:
        plt.close(fig)


def save_discrete_svg(
    t: np.ndarray, X: np.ndarray, out: Path, *, lw: float = 2.0, spacing: float = 1.25
) -> None:
    """Plot every row of ``X`` as a step line and save the figure to ``out`` as SVG.

    Args:
        t: (N,) x-coordinates.
        X: (n_curves, N) values; row ``i`` is shifted up by ``i * spacing``.
        out: destination file; missing parent directories are created.
        lw: line width.
        spacing: vertical offset between consecutive rows.
    """
    fig, ax = _transparent_fig_ax(6.0, 2.2)
    try:
        # Shift each channel once; the same array drives plotting and limits.
        shifted = X.astype(float) + np.arange(X.shape[0])[:, None] * spacing
        for row in shifted:
            ax.step(t, row, where="post", linewidth=lw)

        ymin, ymax = float(np.min(shifted)), float(np.max(shifted))
        ypad = 0.10 * (ymax - ymin + 1e-9)
        ax.set_xlim(t[0], t[-1])
        ax.set_ylim(ymin - ypad, ymax + ypad)

        _write_svg(fig, out)
    finally:
        plt.close(fig)


# ----------------------------
# CLI
# ----------------------------

def main(argv: list[str] | None = None) -> None:
    """Parse command-line options, generate the data and write the SVG files.

    Args:
        argv: argument list to parse; ``None`` makes argparse read ``sys.argv[1:]``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--outdir", type=Path, default=Path("out"))
    ap.add_argument("--seed", type=int, default=7)

    # attention
    ap.add_argument("--T", type=int, default=24)
    ap.add_argument("--attn-mode", type=str, default="local",
                    choices=["local", "global", "causal"])
    ap.add_argument("--colorbar", action="store_true")

    # curves
    ap.add_argument("--seconds", type=float, default=10.0)
    ap.add_argument("--fs", type=int, default=200)
    ap.add_argument("--n-curves", type=int, default=7)

    # discrete optional
    ap.add_argument("--with-discrete", action="store_true")
    ap.add_argument("--disc-vocab", type=int, default=9)
    ap.add_argument("--disc-rate", type=float, default=1.3)

    args = ap.parse_args(argv)

    p = Params(
        seed=args.seed,
        T=args.T,
        n_curves=args.n_curves,
        seconds=args.seconds,
        fs=args.fs,
    )
    rng = np.random.default_rng(args.seed)

    attention_path = args.outdir / "attention_weights.svg"
    trends_path = args.outdir / "token_activation_trends.svg"
    discrete_path = args.outdir / "discrete_tokens.svg"

    # 1) attention map
    A = make_attention_map(args.T, rng, mode=args.attn_mode)
    save_attention_svg(A, attention_path, show_colorbar=args.colorbar)

    # 2) continuous trends
    t, Y = make_token_activation_trends(p)
    save_multi_curve_svg(t, Y, trends_path)

    # 3) discrete trends (optional)
    if args.with_discrete:
        td, X = make_discrete_trends(p, vocab=args.disc_vocab, change_rate_hz=args.disc_rate)
        save_discrete_svg(td, X, discrete_path)

    print("Wrote:")
    print(f" {attention_path}")
    print(f" {trends_path}")
    if args.with_discrete:
        print(f" {discrete_path}")


if __name__ == "__main__":
    main()