Files
internal-docs/arxiv-style/fig-scripts/draw_channels.py
2026-02-09 00:24:40 +08:00

238 lines
7.0 KiB
Python

#!/usr/bin/env python3
"""
Draw *separate* SVG figures for:
1) Continuous channels (multiple smooth curves per figure)
2) Discrete channels (multiple step-like/token curves per figure)
Outputs (default):
out/continuous_channels.svg
out/discrete_channels.svg
Notes:
- Transparent background (good for draw.io / LaTeX / diagrams).
- No axes/frames by default (diagram-friendly).
- Curves are synthetic placeholders; replace `make_*_channels()` with your real data.
"""
from __future__ import annotations
import argparse
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
# ----------------------------
# Data generators (placeholders)
# ----------------------------
@dataclass
class GenParams:
    """Settings consumed by the placeholder channel generators.

    The generators build a time axis of ``int(seconds * fs)`` samples
    spanning ``[0, seconds)`` and seed their random generators from ``seed``.
    """

    # Length of the time axis.
    seconds: float = 10.0
    # Samples per unit of `seconds`.
    fs: int = 200
    # Base seed for the random number generators.
    seed: int = 7
    # Number of continuous channels (curves).
    n_cont: int = 6
    # Number of discrete channels (curves).
    n_disc: int = 5
    # Token/vocab size for discrete channels.
    disc_vocab: int = 8
    # How often discrete tokens change.
    disc_change_rate_hz: float = 1.2
def make_continuous_channels(p: GenParams) -> tuple[np.ndarray, np.ndarray]:
    """
    Generate synthetic smooth curves to stand in for continuous channels.

    Channel ``i`` is the sum of two sinusoids (with a random phase) plus
    moving-average-smoothed Gaussian noise. It is then standardized to zero
    mean / unit variance, scaled by 0.8 and shifted up by ``0.15 * i``.

    Args:
        p: Generation settings; reads ``seed``, ``seconds``, ``fs`` and
            ``n_cont``.

    Returns:
        t: shape (T,), with ``T = int(p.seconds * p.fs)`` samples in
            ``[0, p.seconds)``.
        Y: shape (n_cont, T). Empty (but correctly shaped) when there are
            no channels or no samples.
    """
    rng = np.random.default_rng(p.seed)
    T = int(p.seconds * p.fs)
    t = np.linspace(0, p.seconds, T, endpoint=False)

    # Nothing to synthesize: np.vstack([]) and np.convolve on empty input
    # both raise, so hand back an empty, correctly shaped array instead.
    if T == 0 or p.n_cont <= 0:
        return t, np.empty((max(p.n_cont, 0), T), dtype=float)

    # Moving-average kernel for the noise. With mode="same" np.convolve
    # returns max(len(signal), len(kernel)) samples, so the 9-tap window is
    # clamped to T to keep the noise the same length as the signal.
    width = min(9, T)
    kernel = np.ones(width) / width

    Y = []
    for i in range(p.n_cont):
        # Multi-scale smooth-ish signals: one slow and one faster sinusoid.
        f1 = 0.15 + 0.06 * i
        f2 = 0.8 + 0.15 * (i % 3)
        phase = rng.uniform(0, 2 * np.pi)
        y = (
            0.9 * np.sin(2 * np.pi * f1 * t + phase)
            + 0.35 * np.sin(2 * np.pi * f2 * t + 1.3 * phase)
        )
        # Add mild colored-ish noise by smoothing white noise.
        w = rng.normal(0, 1, size=T)
        w = np.convolve(w, kernel, mode="same")
        y = y + 0.15 * w
        # Normalize each channel for consistent visual scale (the epsilon
        # guards against a zero standard deviation).
        y = (y - np.mean(y)) / (np.std(y) + 1e-9)
        # Vertical offset to separate curves a bit.
        y = 0.8 * y + 0.15 * i
        Y.append(y)
    return t, np.vstack(Y)
def make_discrete_channels(p: GenParams) -> tuple[np.ndarray, np.ndarray]:
    """
    Generate discrete channels as piecewise-constant token IDs (integers).

    For each channel a Poisson-distributed number of change points is drawn
    uniformly over the sample indices. The token is held constant between
    consecutive change points and, at each change point, is redrawn
    uniformly from the vocabulary with probability 0.85 (a redraw may land
    on the same token again).

    Args:
        p: Generation settings; reads ``seed``, ``seconds``, ``fs``,
            ``n_disc``, ``disc_vocab`` and ``disc_change_rate_hz``.

    Returns:
        t: shape (T,), with ``T = int(p.seconds * p.fs)`` samples in
            ``[0, p.seconds)``.
        X: shape (n_disc, T) (integers in [0, disc_vocab-1]). Has zero
            columns when there are no samples.
    """
    # Offset the seed so this generator does not reuse the stream seeded
    # with the bare `p.seed`.
    rng = np.random.default_rng(p.seed + 100)
    T = int(p.seconds * p.fs)
    t = np.linspace(0, p.seconds, T, endpoint=False)
    X = np.zeros((p.n_disc, T), dtype=int)

    # With no samples there is no valid index to draw change points from
    # (rng.integers(0, 0) raises), so return the empty arrays directly.
    if T == 0:
        return t, X

    # Expected number of changes per channel (at least one).
    expected_changes = int(max(1, p.seconds * p.disc_change_rate_hz))

    for c in range(p.n_disc):
        # Pick change points; 0 and T are always included as the outer
        # segment boundaries. np.unique both sorts and de-duplicates, which
        # avoids zero-length segments when a drawn point equals 0.
        k = rng.poisson(expected_changes) + 1
        drawn = rng.integers(0, T, size=k)
        boundaries = np.unique(np.concatenate([[0], drawn, [T]]))

        cur = rng.integers(0, p.disc_vocab)
        for a, b in zip(boundaries[:-1], boundaries[1:]):
            # The first segment keeps the initial token; every later
            # segment redraws it with probability 0.85.
            if a != 0 and rng.random() < 0.85:
                cur = rng.integers(0, p.disc_vocab)
            X[c, a:b] = cur
    return t, X
# ----------------------------
# Plotting helpers
# ----------------------------
def _make_transparent_figure(width_in: float, height_in: float) -> tuple[plt.Figure, plt.Axes]:
    """
    Create a 200-dpi figure holding a single axes, both with transparent patches.

    Args:
        width_in: Figure width in inches.
        height_in: Figure height in inches.

    Returns:
        The new figure and its axes; the axes covers the figure except for a
        3% margin on each side.
    """
    # [left, bottom, width, height] in figure-fraction coordinates.
    axes_rect = [0.03, 0.03, 0.94, 0.94]

    fig = plt.figure(figsize=(width_in, height_in), dpi=200)
    ax = fig.add_axes(axes_rect)

    # Zero alpha on both background patches so nothing is painted behind
    # the curves.
    for background in (fig.patch, ax.patch):
        background.set_alpha(0.0)
    return fig, ax
def save_continuous_channels_svg(
    t: np.ndarray,
    Y: np.ndarray,
    out_path: Path,
    *,
    lw: float = 2.0,
    clean: bool = True,
) -> None:
    """
    Plot multiple continuous curves in one figure and save SVG.

    Args:
        t: Sample positions along the x axis, shape (T,).
        Y: Curve values, shape (n_cont, T); one line is drawn per row.
        out_path: Destination file. Missing parent directories are created.
        lw: Line width passed to matplotlib.
        clean: If True, hide the axes entirely; otherwise label them
            "t" and "value".

    Axis limits are fitted to the data (x to the first/last sample, y to the
    value range plus 8% padding). When `t` or `Y` is empty the corresponding
    limits are left at matplotlib's defaults instead of raising.
    """
    fig, ax = _make_transparent_figure(width_in=6.0, height_in=2.2)
    try:
        # Let matplotlib choose different colors automatically (good defaults).
        for curve in Y:
            ax.plot(t, curve, linewidth=lw)

        if clean:
            ax.set_axis_off()
        else:
            ax.set_xlabel("t")
            ax.set_ylabel("value")

        # Set limits with padding. Reductions over an empty array and
        # t[0] on an empty axis would raise, so each limit is only set
        # when there is data to derive it from.
        if len(t) > 0:
            ax.set_xlim(t[0], t[-1])
        if Y.size > 0:
            ymin, ymax = float(np.min(Y)), float(np.max(Y))
            # The epsilon keeps the padding non-zero for constant data.
            ypad = 0.08 * (ymax - ymin + 1e-9)
            ax.set_ylim(ymin - ypad, ymax + ypad)

        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, format="svg", bbox_inches="tight", pad_inches=0.0, transparent=True)
    finally:
        # pyplot keeps a reference to every figure it creates; close it even
        # if plotting or saving failed so it does not accumulate.
        plt.close(fig)
def save_discrete_channels_svg(
    t: np.ndarray,
    X: np.ndarray,
    out_path: Path,
    *,
    lw: float = 2.0,
    clean: bool = True,
    vertical_spacing: float = 1.25,
) -> None:
    """
    Plot multiple discrete (piecewise-constant) curves in one figure and save SVG.

    Each channel is drawn as a step plot and channel ``i`` is shifted up by
    ``i * vertical_spacing`` to stagger the curves. Curves can still overlap
    when the range of token IDs exceeds `vertical_spacing`.

    Args:
        t: Sample positions along the x axis, shape (T,).
        X: Token IDs, shape (n_disc, T) integers; one step curve per row.
        out_path: Destination file. Missing parent directories are created.
        lw: Line width passed to matplotlib.
        clean: If True, hide the axes entirely; otherwise label them
            "t" and "token id (offset)".
        vertical_spacing: Vertical shift added between consecutive channels.

    Axis limits are fitted to the data (x to the first/last sample, y to the
    shifted value range plus 10% padding). When `t` or `X` is empty the
    corresponding limits are left at matplotlib's defaults instead of raising.
    """
    fig, ax = _make_transparent_figure(width_in=6.0, height_in=2.2)
    try:
        # Shift every row by its channel index once, and reuse the result
        # for both drawing and the y-limit computation.
        offsets = np.arange(X.shape[0]) * vertical_spacing
        shifted = X.astype(float) + offsets[:, None]

        for curve in shifted:
            ax.step(t, curve, where="post", linewidth=lw)

        if clean:
            ax.set_axis_off()
        else:
            ax.set_xlabel("t")
            ax.set_ylabel("token id (offset)")

        # Reductions over an empty array and t[0] on an empty axis would
        # raise (e.g. zero channels), so each limit is only set when there
        # is data to derive it from.
        if len(t) > 0:
            ax.set_xlim(t[0], t[-1])
        if shifted.size > 0:
            ymin, ymax = float(np.min(shifted)), float(np.max(shifted))
            # The epsilon keeps the padding non-zero for constant data.
            ypad = 0.10 * (ymax - ymin + 1e-9)
            ax.set_ylim(ymin - ypad, ymax + ypad)

        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, format="svg", bbox_inches="tight", pad_inches=0.0, transparent=True)
    finally:
        # pyplot keeps a reference to every figure it creates; close it even
        # if plotting or saving failed so it does not accumulate.
        plt.close(fig)
# ----------------------------
# CLI
# ----------------------------
def main() -> None:
    """Parse command-line options, generate both channel sets, and write the two SVGs."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--outdir", type=Path, default=Path("out"))
    # Plain typed options, registered in the same order as they appear in --help.
    typed_options = (
        ("--seed", int, 7),
        ("--seconds", float, 10.0),
        ("--fs", int, 200),
        ("--n-cont", int, 6),
        ("--n-disc", int, 5),
        ("--disc-vocab", int, 8),
        ("--disc-change-rate", float, 1.2),
    )
    for flag, value_type, default_value in typed_options:
        parser.add_argument(flag, type=value_type, default=default_value)
    parser.add_argument("--keep-axes", action="store_true", help="Show axes/labels (default: off)")
    opts = parser.parse_args()

    params = GenParams(
        seconds=opts.seconds,
        fs=opts.fs,
        seed=opts.seed,
        n_cont=opts.n_cont,
        n_disc=opts.n_disc,
        disc_vocab=opts.disc_vocab,
        disc_change_rate_hz=opts.disc_change_rate,
    )

    # Generate both data sets before writing anything.
    t_c, Y = make_continuous_channels(params)
    t_d, X = make_discrete_channels(params)

    cont_path = opts.outdir / "continuous_channels.svg"
    disc_path = opts.outdir / "discrete_channels.svg"

    # Axes are hidden unless --keep-axes was given.
    hide_axes = not opts.keep_axes
    save_continuous_channels_svg(t_c, Y, cont_path, clean=hide_axes)
    save_discrete_channels_svg(t_d, X, disc_path, clean=hide_axes)

    print("Wrote:")
    for written in (cont_path, disc_path):
        print(f" {written}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()