forked from manbo/internal-docs
202 lines
5.9 KiB
Python
202 lines
5.9 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Draw the *Transformer section* lower-half visuals:
|
|
- Continuous channels: multiple smooth curves (like the colored trend lines)
|
|
- Discrete channels: small colored bars/ticks along the bottom
|
|
|
|
Output: ONE SVG with transparent background, axes hidden.
|
|
|
|
Run:
|
|
uv run python draw_transformer_lower_half.py --out ./assets/transformer_lower_half.svg
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib.patches import Rectangle
|
|
|
|
|
|
@dataclass
class Params:
    """Data and styling parameters for the lower-half Transformer figure.

    ``__post_init__`` rejects values that would otherwise surface later as an
    obscure NumPy/matplotlib failure (negative array sizes, an empty time
    axis, bars with zero or negative width, ...). NaN is rejected wherever a
    positive value is required.
    """

    # Base RNG seed; the curve and bar generators are both seeded from it.
    seed: int = 7
    # Duration of the synthetic time axis, in seconds.
    seconds: float = 10.0
    # Samples per second for the continuous curves.
    fs: int = 300

    # Continuous channels
    n_curves: int = 3  # number of smooth curves to draw
    curve_lw: float = 2.4  # matplotlib line width for each curve

    # Discrete bars
    n_bins: int = 40  # number of discrete bars/ticks across time
    bar_height: float = 0.11  # bar height as a fraction of the [0, 1] bar-strip axis
    bar_gap: float = 0.08  # empty space between bars, as a fraction of each bin's width

    # Canvas sizing (matplotlib figsize, inches)
    width_in: float = 5.8
    height_in: float = 1.9

    def __post_init__(self) -> None:
        """Validate field values; raise ValueError on the first invalid one."""
        checks = [
            (self.seconds > 0, f"seconds must be > 0, got {self.seconds!r}"),
            (self.fs > 0, f"fs must be > 0, got {self.fs!r}"),
            (self.n_curves >= 0, f"n_curves must be >= 0, got {self.n_curves!r}"),
            (self.curve_lw >= 0, f"curve_lw must be >= 0, got {self.curve_lw!r}"),
            (self.n_bins >= 0, f"n_bins must be >= 0, got {self.n_bins!r}"),
            (0 < self.bar_height <= 1, f"bar_height must be in (0, 1], got {self.bar_height!r}"),
            (0 <= self.bar_gap < 1, f"bar_gap must be in [0, 1), got {self.bar_gap!r}"),
            (self.width_in > 0, f"width_in must be > 0, got {self.width_in!r}"),
            (self.height_in > 0, f"height_in must be > 0, got {self.height_in!r}"),
        ]
        for ok, message in checks:
            if not ok:
                raise ValueError(message)
|
|
|
|
|
|
def _smooth(x: np.ndarray, win: int) -> np.ndarray:
|
|
win = max(3, int(win) | 1) # odd
|
|
k = np.ones(win, dtype=float)
|
|
k /= k.sum()
|
|
return np.convolve(x, k, mode="same")
|
|
|
|
|
|
def make_continuous_curves(p: Params) -> tuple[np.ndarray, np.ndarray]:
    """Synthesise ``p.n_curves`` smooth curves with slow temporal structure.

    Each curve is built from four parts, then standardised and scaled to a
    standard deviation of about 0.42 so all curves share one visual amplitude:

    - a slow sinusoid;
    - a smaller, faster sinusoid;
    - two Gaussian "bumps" at random times;
    - a little moving-average-smoothed Gaussian noise.

    All randomness comes from ``np.random.default_rng(p.seed)``, drawn in a
    fixed order, so the same ``Params`` always yields the same curves.

    Returns:
        t: Sample times in seconds, shape ``(T,)`` with
            ``T = round(p.seconds * p.fs)``.
        Y: Curve values, shape ``(p.n_curves, T)``; ``(0, T)`` when
            ``p.n_curves == 0``.

    Raises:
        ValueError: if ``p.seconds * p.fs`` yields fewer than 2 samples.
    """
    rng = np.random.default_rng(p.seed)
    # round() rather than int(): e.g. 2.3 s * 100 Hz is 229.999... in floating
    # point and int() would drop a sample.
    n_samples = int(round(p.seconds * p.fs))
    if n_samples < 2:
        raise ValueError(
            f"seconds * fs must give at least 2 samples, got {n_samples} "
            f"(seconds={p.seconds!r}, fs={p.fs!r})"
        )
    t = np.linspace(0, p.seconds, n_samples, endpoint=False)

    base_freqs = [0.12, 0.09, 0.15]  # slow trend, Hz (cycled if n_curves > 3)
    mid_freqs = [0.65, 0.85, 0.75]  # faster wiggle, Hz
    # Keep bump centres 0.8 s away from either end. For short signals, shrink
    # the margin so the uniform range below is never inverted (high < low).
    margin = min(0.8, 0.5 * p.seconds)
    noise_win = int(p.fs * 0.04)  # ~40 ms smoothing window, in samples

    curves: list[np.ndarray] = []
    for i in range(p.n_curves):
        f1 = base_freqs[i % len(base_freqs)]
        f2 = mid_freqs[i % len(mid_freqs)]
        ph = rng.uniform(0, 2 * np.pi)

        # Smooth trend + mid wiggle (the wiggle's phase is tied to the trend's).
        y = (
            1.00 * np.sin(2 * np.pi * f1 * t + ph)
            + 0.35 * np.sin(2 * np.pi * f2 * t + 0.7 * ph)
        )

        # A couple of smooth bumps, like slow pattern changes.
        bumps = np.zeros_like(t)
        for _ in range(2):
            mu = rng.uniform(margin, p.seconds - margin)
            sig = rng.uniform(0.35, 0.75)
            bumps += np.exp(-0.5 * ((t - mu) / sig) ** 2)
        y += 0.55 * bumps

        # Mild smooth noise.
        noise = _smooth(rng.normal(0, 1, size=n_samples), win=noise_win)
        y += 0.12 * noise

        # Standardise, then compress amplitude so the curves fit nicely.
        y = (y - y.mean()) / (y.std() + 1e-9)
        y *= 0.42

        curves.append(y)

    if not curves:
        # np.vstack([]) raises; return an explicitly empty (0, T) array instead.
        return t, np.empty((0, n_samples))
    return t, np.vstack(curves)
|
|
|
|
|
|
def make_discrete_bars(p: Params) -> np.ndarray:
    """Sample a piecewise-constant sequence of category ids, one per bar.

    Bin 0 always draws a fresh id from ``[0, 8)``. Every later bin keeps the
    previous id unless a uniform draw falls below 0.25, which triggers a
    redraw. The generator is seeded with ``p.seed + 123``, a different seed
    from the one used for the curves, so the bars are reproducible.

    Returns:
        Integer array of shape ``(p.n_bins,)`` with values in ``[0, 8)``.
    """
    gen = np.random.default_rng(p.seed + 123)
    labels = np.zeros(p.n_bins, dtype=int)

    # This first draw is always replaced at bin 0, but it still advances the
    # RNG. It is kept so a given seed keeps producing the same bar sequence.
    current = gen.integers(0, 8)
    for idx in range(p.n_bins):
        # Short-circuit: bin 0 redraws without consuming a gen.random() call.
        switch = idx == 0 or gen.random() < 0.25
        if switch:
            current = gen.integers(0, 8)
        labels[idx] = current
    return labels
|
|
|
|
|
|
def _draw_curves(ax, t: np.ndarray, Y: np.ndarray, p: Params) -> None:
    """Plot each row of *Y* against *t* on *ax* and fit the y-limits with 10% padding."""
    # Explicit colours give the "multi-coloured" trend-line look; feel free to
    # swap these hex codes to match your figure theme.
    curve_colors = ["#1f77b4", "#ff7f0e", "#2ca02c"]  # blue / orange / green
    for i, row in enumerate(Y):
        ax.plot(t, row, linewidth=p.curve_lw, color=curve_colors[i % len(curve_colors)])

    ax.set_xlim(t[0], t[-1])
    if Y.size == 0:
        return  # no curves: Y.min() would raise, keep matplotlib's default y-limits
    ymin, ymax = float(Y.min()), float(Y.max())
    pad = 0.10 * (ymax - ymin + 1e-9)  # epsilon keeps the range non-zero for flat data
    ax.set_ylim(ymin - pad, ymax + pad)


def _draw_bars(ax, ids: np.ndarray, x0: float, x1: float, p: Params) -> None:
    """Draw one coloured rectangle per entry of *ids*, tiling [x0, x1] in equal bins on *ax*."""
    # Small categorical palette; ids beyond its length wrap around.
    bar_palette = [
        "#e41a1c", "#377eb8", "#4daf4a", "#984ea3",
        "#ff7f00", "#ffff33", "#a65628", "#f781bf",
    ]

    # Bars live in a [0, 1] y-space inside the strip axes.
    ax.set_xlim(x0, x1)
    ax.set_ylim(0, 1)

    n = len(ids)
    if n == 0:
        return  # nothing to draw (and avoids dividing by zero below)
    bin_w = (x1 - x0) / n
    gap = p.bar_gap * bin_w  # split evenly on both sides of each bar
    width = bin_w - gap
    y_bottom = (1 - p.bar_height) / 2  # vertically centre the bars in the strip

    for i, cat in enumerate(ids):
        ax.add_patch(
            Rectangle(
                (x0 + i * bin_w + gap / 2, y_bottom),
                width,
                p.bar_height,
                facecolor=bar_palette[int(cat) % len(bar_palette)],
                edgecolor="none",
            )
        )


def draw_transformer_lower_half_svg(out_path: Path, p: Params) -> None:
    """Render the curves and the discrete bar strip into one SVG file.

    The curves go in the upper axes and the bars in a thin strip below them,
    sharing one time range. Both axes are hidden and every background is
    transparent, so the SVG can be placed directly onto a diagram. The figure
    is closed even if saving fails, so repeated calls do not accumulate open
    pyplot figures.

    Args:
        out_path: Destination SVG path (a plain string is also accepted);
            missing parent directories are created.
        p: Data and styling parameters.
    """
    out_path = Path(out_path)

    # --- Data (generated first so a bad Params never opens a figure) ---
    t, Y = make_continuous_curves(p)
    ids = make_discrete_bars(p)

    # --- Figure + transparent background ---
    fig = plt.figure(figsize=(p.width_in, p.height_in), dpi=200)
    try:
        fig.patch.set_alpha(0.0)

        # Two stacked axes in figure fractions [left, bottom, width, height]:
        # curves on top, bars in a thin strip below (tight, diagram-style layout).
        ax_curves = fig.add_axes([0.06, 0.28, 0.90, 0.68])
        ax_bars = fig.add_axes([0.06, 0.10, 0.90, 0.14])
        for ax in (ax_curves, ax_bars):
            ax.patch.set_alpha(0.0)
            ax.set_axis_off()

        _draw_curves(ax_curves, t, Y, p)
        _draw_bars(ax_bars, ids, t[0], t[-1], p)

        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, format="svg", transparent=True, bbox_inches="tight", pad_inches=0.0)
    finally:
        plt.close(fig)
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point: parse options, render the SVG, print its path.

    Option defaults are read from a default-constructed ``Params`` instead of
    being repeated as literals, so the CLI and the dataclass cannot drift apart.
    """
    defaults = Params()

    ap = argparse.ArgumentParser(
        description="Draw the Transformer-section lower-half visuals "
        "(continuous curves + discrete bars) as one transparent SVG."
    )
    ap.add_argument(
        "--out",
        type=Path,
        default=Path("transformer_lower_half.svg"),
        help="output SVG path; parent directories are created (default: %(default)s)",
    )
    ap.add_argument("--seed", type=int, default=defaults.seed, help="random seed (default: %(default)s)")
    ap.add_argument(
        "--seconds",
        type=float,
        default=defaults.seconds,
        help="length of the time axis in seconds (default: %(default)s)",
    )
    ap.add_argument(
        "--fs",
        type=int,
        default=defaults.fs,
        help="samples per second for the curves (default: %(default)s)",
    )
    ap.add_argument(
        "--bins",
        type=int,
        default=defaults.n_bins,
        help="number of discrete bars (default: %(default)s)",
    )
    args = ap.parse_args()

    p = Params(seed=args.seed, seconds=args.seconds, fs=args.fs, n_bins=args.bins)
    draw_transformer_lower_half_svg(args.out, p)
    print(f"Wrote: {args.out}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing this module has no side effects.
    main()
|