Files
internal-docs/arxiv-style/fig-scripts/draw_transformer_lower_half_axes.py
2026-02-09 00:24:40 +08:00

203 lines
5.5 KiB
Python

#!/usr/bin/env python3
"""
Transformer section lower-half visuals WITH AXES ONLY:
- Axes spines (frame lines) visible
- NO numbers (tick marks and tick labels removed)
- NO words (axis labels and titles cleared)
- Transparent background
- One SVG output
Run:
uv run python draw_transformer_lower_half_axes.py --out ./assets/transformer_lower_half_axes_only.svg
"""
from __future__ import annotations
import argparse
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
@dataclass
class Params:
    """Knobs for the synthetic lower-half Transformer figure.

    Every field has a default, so ``Params()`` reproduces the reference figure.
    Inconsistent values are rejected in ``__post_init__`` with ``ValueError``
    instead of failing later inside numpy/matplotlib.
    """

    seed: int = 7  # curves use ``seed``; the bar strip uses ``seed + 123``
    seconds: float = 10.0  # length of the synthetic time axis
    fs: int = 300  # samples per unit of ``seconds``; T = int(seconds * fs)
    # Continuous channels
    n_curves: int = 3  # curves in the top panel (colors cycle after 3)
    curve_lw: float = 2.4  # passed as matplotlib ``linewidth`` for each curve
    # Discrete bars
    n_bins: int = 40  # equal-width category cells in the bottom strip
    bar_height: float = 0.55  # fraction of the discrete-axis y-range
    bar_gap: float = 0.08  # fraction of each cell's width left empty (split on both sides)
    # Figure size
    width_in: float = 6.6  # figure width, inches
    height_in: float = 2.6  # figure height, inches

    def __post_init__(self) -> None:
        """Validate field combinations; raise ``ValueError`` on nonsense input."""
        if self.seconds <= 0 or self.fs <= 0:
            raise ValueError(
                f"seconds and fs must be positive (got seconds={self.seconds}, fs={self.fs})"
            )
        # The x-range is [t[0], t[-1]]; it needs two samples to have a non-zero span.
        if int(self.seconds * self.fs) < 2:
            raise ValueError(
                "seconds * fs must yield at least 2 samples "
                f"(got {int(self.seconds * self.fs)})"
            )
        if self.n_curves < 1:
            raise ValueError(f"n_curves must be >= 1 (got {self.n_curves})")
        if self.n_bins < 0:
            raise ValueError(f"n_bins must be >= 0 (got {self.n_bins})")
        if not 0.0 < self.bar_height <= 1.0:
            raise ValueError(f"bar_height must be in (0, 1] (got {self.bar_height})")
        # A gap of a full cell or more would produce zero/negative-width bars.
        if not 0.0 <= self.bar_gap < 1.0:
            raise ValueError(f"bar_gap must be in [0, 1) (got {self.bar_gap})")
        if self.width_in <= 0 or self.height_in <= 0:
            raise ValueError(
                f"figure size must be positive (got {self.width_in} x {self.height_in})"
            )
def _smooth(x: np.ndarray, win: int) -> np.ndarray:
win = max(3, int(win) | 1) # odd
k = np.ones(win, dtype=float)
k /= k.sum()
return np.convolve(x, k, mode="same")
def make_continuous_curves(p: Params) -> tuple[np.ndarray, np.ndarray]:
    """Synthesize ``p.n_curves`` smooth, normalized signals on a shared time grid.

    Each curve has four components:

    - a slow sinusoid;
    - a weaker, faster sinusoid;
    - two Gaussian bumps;
    - a little smoothed noise.

    The sum is then z-scored and scaled by 0.42, so every curve has roughly
    the same visual amplitude. Frequencies cycle through three presets when
    ``n_curves > 3``.

    Args:
        p: figure parameters; reads ``seed``, ``seconds``, ``fs`` and ``n_curves``.

    Returns:
        ``(t, Y)``. ``t`` has shape ``(T,)`` with ``T = int(p.seconds * p.fs)``
        samples on ``[0, p.seconds)`` (``endpoint=False``). ``Y`` has shape
        ``(p.n_curves, T)``.

    Raises:
        ValueError: if ``p.seconds * p.fs`` yields no samples, or (via
            ``np.vstack``) if ``p.n_curves < 1``.
    """
    rng = np.random.default_rng(p.seed)
    n_samples = int(p.seconds * p.fs)
    if n_samples < 1:
        raise ValueError(
            f"p.seconds * p.fs must yield at least one sample (got {n_samples})"
        )
    t = np.linspace(0, p.seconds, n_samples, endpoint=False)

    base_freqs = [0.12, 0.09, 0.15]
    mid_freqs = [0.65, 0.85, 0.75]
    # Bump centres are drawn from [0.8, seconds - 0.8] so bumps sit inside the
    # window. For seconds < 1.6 that interval would be inverted, which is
    # undefined for Generator.uniform, so it collapses onto the midpoint.
    # For seconds >= 1.6 the bounds are exactly the original ones.
    mu_lo = min(0.8, p.seconds / 2)
    mu_hi = max(p.seconds - 0.8, mu_lo)

    curves = []
    for i in range(p.n_curves):
        f1 = base_freqs[i % len(base_freqs)]
        f2 = mid_freqs[i % len(mid_freqs)]
        ph = rng.uniform(0, 2 * np.pi)
        y = (
            1.00 * np.sin(2 * np.pi * f1 * t + ph)
            + 0.35 * np.sin(2 * np.pi * f2 * t + 0.7 * ph)
        )
        bumps = np.zeros_like(t)
        for _ in range(2):
            mu = rng.uniform(mu_lo, mu_hi)
            sig = rng.uniform(0.35, 0.75)
            bumps += np.exp(-0.5 * ((t - mu) / sig) ** 2)
        y += 0.55 * bumps
        # Moving-average window of 0.04 * fs samples (0.04 units of `seconds`).
        noise = _smooth(rng.normal(0, 1, size=n_samples), win=int(p.fs * 0.04))
        y += 0.12 * noise
        # z-score, then scale; 1e-9 guards against a perfectly flat curve.
        y = (y - y.mean()) / (y.std() + 1e-9)
        y *= 0.42
        curves.append(y)
    return t, np.vstack(curves)
def make_discrete_bars(p: Params) -> np.ndarray:
    """Sample a piecewise-constant sequence of ``p.n_bins`` category ids in ``[0, 8)``.

    The first cell always gets a fresh id. Each later cell keeps the previous
    id unless a 25% coin flip triggers a redraw, which produces runs of
    repeated ids. A dedicated generator seeded with ``p.seed + 123`` is used,
    so the strip does not share a random stream with the curves.

    Returns:
        Integer array of length ``p.n_bins``.
    """
    rng = np.random.default_rng(p.seed + 123)
    labels = np.zeros(p.n_bins, dtype=int)
    # This draw is always overwritten at index 0 below, but it advances the
    # generator; it is kept so a given seed keeps producing the same strip.
    current = rng.integers(0, 8)
    for idx in range(labels.size):
        # Index 0 never flips the coin; later cells flip it exactly once.
        keep_previous = idx > 0 and rng.random() >= 0.25
        if not keep_previous:
            current = rng.integers(0, 8)
        labels[idx] = current
    return labels
def _axes_only(ax: plt.Axes) -> None:
    """Strip ``ax`` down to its bare rectangular frame.

    The function:

    - clears the x/y axis labels and the title;
    - forces the top/right/bottom/left spines visible;
    - removes all tick locations;
    - switches off tick marks and tick labels via ``tick_params``;
    - disables the grid.

    The result is a panel with no numbers and no words.
    """
    # Blank out the text the axes could carry.
    for clear_text in (ax.set_xlabel, ax.set_ylabel, ax.set_title):
        clear_text("")
    # The frame lines are the only axes decoration that remains.
    for name in ("top", "right", "bottom", "left"):
        ax.spines[name].set_visible(True)
    # No tick positions at all...
    ax.set_xticks([])
    ax.set_yticks([])
    # ...and tick marks / tick labels switched off on every side.
    hidden = dict.fromkeys(
        ("bottom", "left", "top", "right", "labelbottom", "labelleft"), False
    )
    ax.tick_params(axis="both", which="both", **hidden)
    ax.grid(False)
def _draw_curves(ax: plt.Axes, t: np.ndarray, Y: np.ndarray, p: Params) -> None:
    """Plot every row of ``Y`` against ``t`` and fit the view to the data.

    Colors cycle blue / orange / green. The x-range is exactly
    ``[t[0], t[-1]]`` and the y-range gets 10% padding above and below.
    """
    curve_colors = ["#1f77b4", "#ff7f0e", "#2ca02c"]  # blue / orange / green
    for i, y in enumerate(Y):
        ax.plot(t, y, linewidth=p.curve_lw, color=curve_colors[i % len(curve_colors)])
    ymin, ymax = float(Y.min()), float(Y.max())
    # 1e-9 keeps a perfectly flat signal from producing a zero-height range.
    ypad = 0.10 * (ymax - ymin + 1e-9)
    ax.set_xlim(t[0], t[-1])
    ax.set_ylim(ymin - ypad, ymax + ypad)


def _draw_bars(ax: plt.Axes, ids: np.ndarray, x0: float, x1: float, p: Params) -> None:
    """Tile ``[x0, x1]`` with one colored rectangle per entry of ``ids``.

    Each cell is ``(x1 - x0) / len(ids)`` wide. ``p.bar_gap`` of that width is
    left empty, split evenly on both sides. The bar is vertically centred in a
    ``[0, 1]`` y-range with height ``p.bar_height``. Colors index an 8-entry
    palette by ``id % 8``.
    """
    bar_palette = [
        "#e41a1c", "#377eb8", "#4daf4a", "#984ea3",
        "#ff7f00", "#ffff33", "#a65628", "#f781bf",
    ]
    ax.set_xlim(x0, x1)
    ax.set_ylim(0, 1)
    n = len(ids)
    if n == 0:
        return  # nothing to draw; avoids dividing by zero below
    bin_w = (x1 - x0) / n
    gap = p.bar_gap * bin_w
    width = bin_w - gap
    bar_y = (1 - p.bar_height) / 2  # centre the strip vertically
    for i, cat in enumerate(ids):
        left = x0 + i * bin_w + gap / 2
        color = bar_palette[int(cat) % len(bar_palette)]
        ax.add_patch(Rectangle((left, bar_y), width, p.bar_height, facecolor=color, edgecolor="none"))


def draw_transformer_lower_half_svg(out_path: Path, p: Params) -> None:
    """Render the lower-half Transformer figure to ``out_path`` as a transparent SVG.

    Layout: a tall top panel holds the continuous curves and a thin bottom
    panel holds the discrete category strip; the two panels share the x-axis.
    The strip spans the same ``[t[0], t[-1]]`` range as the curves. Both
    panels are reduced to bare frames by ``_axes_only``.

    Parent directories of ``out_path`` are created if missing. The figure is
    always closed, even if drawing or saving raises, so repeated calls do not
    accumulate open pyplot figures.
    """
    fig = plt.figure(figsize=(p.width_in, p.height_in), dpi=200)
    try:
        fig.patch.set_alpha(0.0)
        # Two axes sharing x (top curves, bottom bars); rects are [left, bottom, w, h].
        ax_curves = fig.add_axes([0.10, 0.38, 0.86, 0.56])
        ax_bars = fig.add_axes([0.10, 0.14, 0.86, 0.18], sharex=ax_curves)
        ax_curves.patch.set_alpha(0.0)
        ax_bars.patch.set_alpha(0.0)

        t, Y = make_continuous_curves(p)
        ids = make_discrete_bars(p)

        _draw_curves(ax_curves, t, Y, p)
        _draw_bars(ax_bars, ids, t[0], t[-1], p)

        # "Axes only" styling: frames stay, numbers and words go.
        _axes_only(ax_curves)
        _axes_only(ax_bars)

        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, format="svg", transparent=True, bbox_inches="tight", pad_inches=0.0)
    finally:
        plt.close(fig)
def main() -> None:
    """Command-line entry point: parse options, render the SVG, report the path.

    Numeric defaults are read from ``Params()`` so the CLI cannot drift from
    the dataclass. Invalid parameter combinations rejected by ``Params``
    (``ValueError``) are reported as a normal argparse usage error.
    """
    defaults = Params()
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument(
        "--out", type=Path, default=Path("transformer_lower_half_axes_only.svg"),
        help="output SVG path; parent directories are created (default: %(default)s)",
    )
    ap.add_argument("--seed", type=int, default=defaults.seed,
                    help="random seed (default: %(default)s)")
    ap.add_argument("--seconds", type=float, default=defaults.seconds,
                    help="length of the synthetic time axis (default: %(default)s)")
    ap.add_argument("--fs", type=int, default=defaults.fs,
                    help="samples per unit of --seconds (default: %(default)s)")
    ap.add_argument("--bins", type=int, default=defaults.n_bins,
                    help="cells in the bottom category strip (default: %(default)s)")
    ap.add_argument("--curves", type=int, default=defaults.n_curves,
                    help="curves in the top panel (default: %(default)s)")
    args = ap.parse_args()
    try:
        p = Params(seed=args.seed, seconds=args.seconds, fs=args.fs, n_bins=args.bins, n_curves=args.curves)
    except ValueError as err:
        ap.error(str(err))  # prints usage and exits with status 2
    draw_transformer_lower_half_svg(args.out, p)
    print(f"Wrote: {args.out}")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()