#!/usr/bin/env python3
"""
Draw the *Transformer section* lower-half visuals:
- Continuous channels: multiple smooth curves (like the colored trend lines)
- Discrete channels: small colored bars/ticks along the bottom

Output: ONE SVG with transparent background, axes hidden.

Run:
    uv run python draw_transformer_lower_half.py --out ./assets/transformer_lower_half.svg
"""

from __future__ import annotations

import argparse
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle


@dataclass
class Params:
    """Tunable knobs for the synthetic data and the figure layout."""

    seed: int = 7  # base RNG seed; the bar generator uses seed + 123
    seconds: float = 10.0  # length of the time axis; t spans [0, seconds)
    fs: int = 300  # samples per second; total samples T = int(seconds * fs)

    # Continuous channels
    n_curves: int = 3
    curve_lw: float = 2.4  # line width passed to Axes.plot

    # Discrete bars
    n_bins: int = 40  # number of discrete bars/ticks across time
    bar_height: float = 0.11  # bar height as a fraction of the bar strip axes (y-range [0, 1])
    bar_gap: float = 0.08  # gap between neighbouring bars, as a fraction of the bin width

    # Canvas sizing (inches)
    width_in: float = 5.8
    height_in: float = 1.9


def _smooth(x: np.ndarray, win: int) -> np.ndarray:
    """
    Return a box-filter moving average of *x* with the same length as *x*.

    The window is forced to be odd and at least 3. It is then capped at
    ``len(x)``, because ``np.convolve(..., mode="same")`` returns
    ``max(len(x), len(kernel))`` samples. An oversized window would therefore
    change the output length. The convolution zero-pads, so values near both
    ends are attenuated.
    """
    n = len(x)
    if n == 0:
        return np.asarray(x, dtype=float).copy()
    win = max(3, int(win) | 1)  # odd
    if win > n:
        win = n if n % 2 else n - 1  # largest odd window that fits the signal
    k = np.ones(win, dtype=float)
    k /= k.sum()
    return np.convolve(x, k, mode="same")


def make_continuous_curves(p: Params) -> tuple[np.ndarray, np.ndarray]:
    """
    Produce ``p.n_curves`` smooth curves with gentle long-term temporal patterning.

    Each curve is built from four parts:
    - a slow sinusoid plus a faster, smaller one;
    - two Gaussian bumps;
    - lightly smoothed Gaussian noise;
    - finally, z-scoring and scaling by 0.42.

    Returns:
        t: (T,) sample times in [0, p.seconds), with T = int(p.seconds * p.fs)
        Y: (n_curves, T)

    Raises:
        ValueError: if fewer than 2 samples would be produced, or n_curves < 1.
    """
    rng = np.random.default_rng(p.seed)
    T = int(p.seconds * p.fs)
    if T < 2:
        raise ValueError(f"seconds * fs must yield at least 2 samples (got {T})")
    if p.n_curves < 1:
        raise ValueError(f"n_curves must be at least 1 (got {p.n_curves})")
    t = np.linspace(0, p.seconds, T, endpoint=False)

    # Keep bump centres 0.8 s away from both ends when the duration allows it.
    # For very short signals, use the midpoint rather than passing low > high
    # to rng.uniform (documented as undefined). One draw is still consumed,
    # so the RNG stream is unchanged.
    mu_lo, mu_hi = 0.8, p.seconds - 0.8
    if mu_hi < mu_lo:
        mu_lo = mu_hi = p.seconds / 2

    base_freqs = [0.12, 0.09, 0.15]
    mid_freqs = [0.65, 0.85, 0.75]
    Y = []
    for i in range(p.n_curves):
        f1 = base_freqs[i % len(base_freqs)]
        f2 = mid_freqs[i % len(mid_freqs)]
        ph = rng.uniform(0, 2 * np.pi)

        # Smooth trend + mid wiggle
        y = (
            1.00 * np.sin(2 * np.pi * f1 * t + ph)
            + 0.35 * np.sin(2 * np.pi * f2 * t + 0.7 * ph)
        )

        # Add a couple of smooth bumps (like slow pattern changes)
        bumps = np.zeros_like(t)
        for _ in range(2):
            mu = rng.uniform(mu_lo, mu_hi)
            sig = rng.uniform(0.35, 0.75)
            bumps += np.exp(-0.5 * ((t - mu) / sig) ** 2)
        y += 0.55 * bumps

        # Mild smooth noise (~40 ms box window)
        noise = _smooth(rng.normal(0, 1, size=T), win=int(p.fs * 0.04))
        y += 0.12 * noise

        # Normalize and compress amplitude to fit nicely
        y = (y - y.mean()) / (y.std() + 1e-9)
        y *= 0.42
        Y.append(y)

    return t, np.vstack(Y)


def make_discrete_bars(p: Params) -> np.ndarray:
    """
    Generate discrete "token-like" bars across time bins.

    The sequence is piecewise constant. The first bin draws a category in
    [0, 8), and each later bin redraws with probability 0.25.

    Returns:
        ids: (n_bins,) integer category ids in [0, 8)
    """
    rng = np.random.default_rng(p.seed + 123)
    n = p.n_bins

    ids = np.zeros(n, dtype=int)
    # This initial draw is replaced at i == 0. It is kept so the RNG stream,
    # and hence the output for a given seed, stays unchanged.
    cur = rng.integers(0, 8)
    for i in range(n):
        # `or` short-circuits at i == 0, so no rng.random() draw is consumed there.
        if i == 0 or rng.random() < 0.25:
            cur = rng.integers(0, 8)
        ids[i] = cur
    return ids


def _draw_curves(ax: plt.Axes, t: np.ndarray, Y: np.ndarray, lw: float) -> None:
    """Plot each row of *Y* against *t*; fit the view to the data with 10% vertical padding."""
    # Explicit colors give the "multi-colored" look; swap these hex codes to match your theme.
    curve_colors = ["#1f77b4", "#ff7f0e", "#2ca02c"]  # blue / orange / green
    for i, y in enumerate(Y):
        ax.plot(t, y, linewidth=lw, color=curve_colors[i % len(curve_colors)])

    ymin, ymax = float(Y.min()), float(Y.max())
    pad = 0.10 * (ymax - ymin + 1e-9)
    ax.set_xlim(t[0], t[-1])
    ax.set_ylim(ymin - pad, ymax + pad)


def _draw_bars(ax: plt.Axes, ids: np.ndarray, x0: float, x1: float, p: Params) -> None:
    """Draw one colored rectangle per category id, evenly tiling [x0, x1] in data x-units."""
    # Small palette for categories; ids beyond its length wrap around.
    bar_palette = [
        "#e41a1c", "#377eb8", "#4daf4a", "#984ea3",
        "#ff7f00", "#ffff33", "#a65628", "#f781bf",
    ]

    # Bars live in [0, 1] y-space inside this axes.
    ax.set_xlim(x0, x1)
    ax.set_ylim(0, 1)

    n = len(ids)
    if n == 0:
        return  # nothing to tile; also avoids dividing by zero below

    bin_w = (x1 - x0) / n
    gap = p.bar_gap * bin_w
    width = bin_w - gap
    bottom = (1 - p.bar_height) / 2  # vertically centre the bars in the strip
    for i, cat in enumerate(ids):
        left = x0 + i * bin_w + gap / 2
        ax.add_patch(
            Rectangle(
                (left, bottom),
                width,
                p.bar_height,
                facecolor=bar_palette[int(cat) % len(bar_palette)],
                edgecolor="none",
            )
        )


def draw_transformer_lower_half_svg(out_path: Path, p: Params) -> None:
    """
    Render the curves (top strip) and discrete bars (bottom strip) to an SVG.

    The figure and axes backgrounds are transparent and axes are hidden.
    Parent directories of *out_path* are created if missing. The figure is
    closed even if saving fails.
    """
    # Generate data first so invalid parameters fail before any figure exists.
    t, Y = make_continuous_curves(p)
    ids = make_discrete_bars(p)

    fig = plt.figure(figsize=(p.width_in, p.height_in), dpi=200)
    try:
        fig.patch.set_alpha(0.0)

        # Two stacked axes in a tight, diagram-style layout: curves (top), bars (bottom).
        ax_curves = fig.add_axes([0.06, 0.28, 0.90, 0.68])  # [left, bottom, width, height]
        ax_bars = fig.add_axes([0.06, 0.10, 0.90, 0.14])
        for ax in (ax_curves, ax_bars):
            ax.patch.set_alpha(0.0)
            ax.set_axis_off()

        _draw_curves(ax_curves, t, Y, p.curve_lw)
        # Bars span the same time range as the curves so the strips line up.
        _draw_bars(ax_bars, ids, t[0], t[-1], p)

        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, format="svg", transparent=True, bbox_inches="tight", pad_inches=0.0)
    finally:
        plt.close(fig)


def main() -> None:
    """Parse CLI arguments, render the SVG, and print the output path."""
    ap = argparse.ArgumentParser(
        description="Draw the Transformer-section lower-half visuals as a transparent SVG."
    )
    ap.add_argument("--out", type=Path, default=Path("transformer_lower_half.svg"),
                    help="output SVG path (parent directories are created)")
    ap.add_argument("--seed", type=int, default=7, help="random seed")
    ap.add_argument("--seconds", type=float, default=10.0, help="length of the time axis")
    ap.add_argument("--fs", type=int, default=300, help="samples per second")
    ap.add_argument("--bins", type=int, default=40, help="number of discrete bars")
    args = ap.parse_args()

    # Reject values that would crash deep inside numpy/matplotlib with an obscure error.
    if not (0 < args.seconds < float("inf")):
        ap.error("--seconds must be a positive finite number")
    if args.fs <= 0:
        ap.error("--fs must be positive")
    if int(args.seconds * args.fs) < 2:
        ap.error("--seconds * --fs must give at least 2 samples")
    if args.bins < 0:
        ap.error("--bins must be non-negative")

    p = Params(seed=args.seed, seconds=args.seconds, fs=args.fs, n_bins=args.bins)
    draw_transformer_lower_half_svg(args.out, p)
    print(f"Wrote: {args.out}")


if __name__ == "__main__":
    main()