forked from manbo/internal-docs
238 lines
7.0 KiB
Python
238 lines
7.0 KiB
Python
#!/usr/bin/env python3
"""
Draw *separate* SVG figures for:

1) Continuous channels (multiple smooth curves per figure)
2) Discrete channels (multiple step-like/token curves per figure)

Outputs (default):

    out/continuous_channels.svg
    out/discrete_channels.svg

Notes:

- Transparent background (good for draw.io / LaTeX / diagrams).
- No axes/frames by default (diagram-friendly); pass --keep-axes to show them.
- Curves are synthetic placeholders; replace `make_*_channels()` with your real data.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
# ----------------------------
|
|
# Data generators (placeholders)
|
|
# ----------------------------
|
|
|
|
@dataclass
class GenParams:
    """
    Knobs for the synthetic placeholder generators.

    Both generators share the timeline (``seconds`` of data sampled at ``fs``
    samples per second, i.e. T = int(seconds * fs) samples) and derive their
    random streams from ``seed``.
    """

    # --- shared timeline / randomness ---
    seconds: float = 10.0
    fs: int = 200
    seed: int = 7

    # --- continuous figure ---
    n_cont: int = 6  # how many smooth curves to draw

    # --- discrete figure ---
    n_disc: int = 5  # how many token curves to draw
    disc_vocab: int = 8  # token ids come from [0, disc_vocab - 1]
    disc_change_rate_hz: float = 1.2  # rough number of token changes per second
|
|
|
|
|
|
def make_continuous_channels(p: GenParams) -> tuple[np.ndarray, np.ndarray]:
    """
    Generate smooth synthetic placeholder curves.

    Each channel i is a slow sinusoid (f1 = 0.15 + 0.06*i Hz) plus a faster one
    (f2 = 0.8 + 0.15*(i % 3) Hz) sharing a random phase, plus Gaussian noise
    smoothed by a 9-tap moving average. The result is z-normalized, scaled by
    0.8 and shifted up by 0.15*i so the curves separate a little.

    Args:
        p: generation parameters; reads ``seconds``, ``fs``, ``seed``, ``n_cont``.

    Returns:
        t: sample times in seconds, shape (T,) with T = int(seconds * fs).
        Y: float curves, shape (max(n_cont, 0), T).
    """
    rng = np.random.default_rng(p.seed)
    T = int(p.seconds * p.fs)
    t = np.linspace(0, p.seconds, T, endpoint=False)

    # Preallocate so n_cont == 0 yields a (0, T) array instead of np.vstack([]) failing.
    n = max(p.n_cont, 0)
    Y = np.zeros((n, T))
    if T == 0:
        # np.convolve rejects empty input and there is nothing to draw anyway.
        return t, Y

    kernel = np.ones(9) / 9.0  # moving average -> mildly "colored" noise
    for i in range(n):
        # Multi-scale smooth-ish signals
        f1 = 0.15 + 0.06 * i
        f2 = 0.8 + 0.15 * (i % 3)
        phase = rng.uniform(0, 2 * np.pi)
        y = (
            0.9 * np.sin(2 * np.pi * f1 * t + phase)
            + 0.35 * np.sin(2 * np.pi * f2 * t + 1.3 * phase)
        )

        # mode="same" implicitly zero-pads, so the noise is slightly damped at both ends.
        w = np.convolve(rng.normal(0, 1, size=T), kernel, mode="same")
        y = y + 0.15 * w

        # Normalize each channel for a consistent visual scale, then offset it.
        y = (y - np.mean(y)) / (np.std(y) + 1e-9)
        Y[i] = 0.8 * y + 0.15 * i

    return t, Y
|
|
|
|
|
|
def make_discrete_channels(p: GenParams) -> tuple[np.ndarray, np.ndarray]:
    """
    Generate discrete channels as piecewise-constant integer token ids.

    Per channel, Poisson(max(1, int(seconds * disc_change_rate_hz))) + 1 random
    change points are drawn (duplicates merged). At each change point after
    t=0 a new token is drawn with probability 0.85; the new draw may repeat
    the previous token.

    Args:
        p: generation parameters; reads ``seconds``, ``fs``, ``seed``,
            ``n_disc``, ``disc_vocab`` and ``disc_change_rate_hz``.

    Returns:
        t: sample times in seconds, shape (T,) with T = int(seconds * fs).
        X: shape (max(n_disc, 0), T), integers in [0, disc_vocab - 1].

    Raises:
        ValueError: if ``disc_vocab`` < 1 (no token ids to draw from).
    """
    if p.disc_vocab < 1:
        raise ValueError(f"disc_vocab must be >= 1, got {p.disc_vocab}")

    # Separate stream from the continuous generator so the two figures differ.
    rng = np.random.default_rng(p.seed + 100)
    T = int(p.seconds * p.fs)
    t = np.linspace(0, p.seconds, T, endpoint=False)

    X = np.zeros((max(p.n_disc, 0), T), dtype=int)
    if T == 0:
        # rng.integers(0, 0) would raise; an empty timeline has nothing to fill.
        return t, X

    # Poisson mean for the number of change points per channel.
    expected_changes = int(max(1, p.seconds * p.disc_change_rate_hz))

    for c in range(X.shape[0]):
        k = rng.poisson(expected_changes) + 1
        # np.unique returns sorted values in [0, T), so the bounds stay sorted.
        change_pts = np.unique(rng.integers(0, T, size=k))
        bounds = np.concatenate([[0], change_pts, [T]])

        cur = rng.integers(0, p.disc_vocab)
        for a, b in zip(bounds[:-1], bounds[1:]):
            # Frequent (85%) token redraw at every change point except the start.
            if a != 0 and rng.random() < 0.85:
                cur = rng.integers(0, p.disc_vocab)
            X[c, a:b] = cur

    return t, X
|
|
|
|
|
|
# ----------------------------
|
|
# Plotting helpers
|
|
# ----------------------------
|
|
|
|
def _make_transparent_figure(width_in: float, height_in: float) -> tuple[plt.Figure, plt.Axes]:
    """
    Create a 200-dpi figure with one axes, both with fully transparent backgrounds.

    The axes cover the figure except for a 3% margin on every side.
    """
    figure = plt.figure(figsize=(width_in, height_in), dpi=200)
    margin = 0.03
    extent = 1.0 - 2 * margin
    axes = figure.add_axes([margin, margin, extent, extent])
    for background in (figure.patch, axes.patch):
        background.set_alpha(0.0)
    return figure, axes
|
|
|
|
|
|
def save_continuous_channels_svg(
    t: np.ndarray,
    Y: np.ndarray,
    out_path: Path,
    *,
    lw: float = 2.0,
    clean: bool = True,
) -> None:
    """
    Plot several continuous curves in one transparent figure and save it as SVG.

    Args:
        t: sample times, shape (T,).
        Y: curves, shape (n_cont, T); a 1-D array of shape (T,) is drawn as one curve.
        out_path: destination file (str or Path); missing parent directories are created.
        lw: line width used for every curve.
        clean: if True, hide the axes entirely; otherwise label the x and y axes.
    """
    t = np.asarray(t)
    Y = np.atleast_2d(np.asarray(Y))
    out_path = Path(out_path)

    fig, ax = _make_transparent_figure(width_in=6.0, height_in=2.2)
    try:
        # One line per row; matplotlib's default color cycle tells them apart.
        for row in Y:
            ax.plot(t, row, linewidth=lw)

        if clean:
            ax.set_axis_off()
        else:
            ax.set_xlabel("t")
            ax.set_ylabel("value")

        # Tight x-range and 8% vertical padding; skipped for empty input,
        # where np.min / t[0] would raise.
        if t.size and Y.size:
            ymin, ymax = float(np.min(Y)), float(np.max(Y))
            ypad = 0.08 * (ymax - ymin + 1e-9)
            ax.set_xlim(t[0], t[-1])
            ax.set_ylim(ymin - ypad, ymax + ypad)

        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, format="svg", bbox_inches="tight", pad_inches=0.0, transparent=True)
    finally:
        # Release the pyplot figure even if plotting or saving fails.
        plt.close(fig)
|
|
|
|
|
|
def save_discrete_channels_svg(
    t: np.ndarray,
    X: np.ndarray,
    out_path: Path,
    *,
    lw: float = 2.0,
    clean: bool = True,
    vertical_spacing: float = 1.25,
) -> None:
    """
    Plot several discrete (piecewise-constant) curves in one transparent figure and save it as SVG.

    Each channel is drawn as a post-step plot (a value holds until the next
    sample) and shifted up by ``i * vertical_spacing``. Channels stay in
    separate bands only when ``vertical_spacing`` exceeds each channel's value
    span (max - min); otherwise neighbouring channels overlap, so pass a larger
    spacing for wide token vocabularies.

    Args:
        t: sample times, shape (T,).
        X: token ids, shape (n_disc, T); a 1-D array is drawn as one channel.
        out_path: destination file (str or Path); missing parent directories are created.
        lw: line width used for every curve.
        clean: if True, hide the axes entirely; otherwise label the x and y axes.
        vertical_spacing: vertical offset between consecutive channels.
    """
    t = np.asarray(t)
    X = np.atleast_2d(np.asarray(X))
    out_path = Path(out_path)

    # Apply the per-channel offsets once and reuse them for drawing and limits.
    offsets = np.arange(X.shape[0])[:, None] * vertical_spacing
    Y = X.astype(float) + offsets

    fig, ax = _make_transparent_figure(width_in=6.0, height_in=2.2)
    try:
        for row in Y:
            ax.step(t, row, where="post", linewidth=lw)

        if clean:
            ax.set_axis_off()
        else:
            ax.set_xlabel("t")
            ax.set_ylabel("token id (offset)")

        # Tight x-range and 10% vertical padding; skipped for empty input,
        # where np.min / t[0] would raise.
        if t.size and Y.size:
            ymin, ymax = float(np.min(Y)), float(np.max(Y))
            ypad = 0.10 * (ymax - ymin + 1e-9)
            ax.set_xlim(t[0], t[-1])
            ax.set_ylim(ymin - ypad, ymax + ypad)

        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, format="svg", bbox_inches="tight", pad_inches=0.0, transparent=True)
    finally:
        # Release the pyplot figure even if plotting or saving fails.
        plt.close(fig)
|
|
|
|
|
|
# ----------------------------
|
|
# CLI
|
|
# ----------------------------
|
|
|
|
def main() -> None:
    """
    Command-line entry point: generate both placeholder datasets and write the two SVGs.

    Flag defaults come from ``GenParams`` so the CLI and the dataclass cannot
    drift apart. Sizes that would give an empty timeline or no token ids are
    rejected with an argparse usage error instead of failing inside numpy.
    """
    defaults = GenParams()

    ap = argparse.ArgumentParser(
        description="Draw separate SVG figures for continuous and discrete channels."
    )
    ap.add_argument("--outdir", type=Path, default=Path("out"))
    ap.add_argument("--seed", type=int, default=defaults.seed)
    ap.add_argument("--seconds", type=float, default=defaults.seconds)
    ap.add_argument("--fs", type=int, default=defaults.fs)

    ap.add_argument("--n-cont", type=int, default=defaults.n_cont)
    ap.add_argument("--n-disc", type=int, default=defaults.n_disc)
    ap.add_argument("--disc-vocab", type=int, default=defaults.disc_vocab)
    ap.add_argument("--disc-change-rate", type=float, default=defaults.disc_change_rate_hz)

    ap.add_argument("--keep-axes", action="store_true", help="Show axes/labels (default: off)")
    args = ap.parse_args()

    # ap.error prints usage and exits with status 2.
    if args.seconds <= 0 or args.fs <= 0 or int(args.seconds * args.fs) < 1:
        ap.error("--seconds and --fs must be positive and give at least one sample")
    if args.n_cont < 0 or args.n_disc < 0:
        ap.error("--n-cont and --n-disc must be non-negative")
    if args.disc_vocab < 1:
        ap.error("--disc-vocab must be at least 1")

    p = GenParams(
        seconds=args.seconds,
        fs=args.fs,
        seed=args.seed,
        n_cont=args.n_cont,
        n_disc=args.n_disc,
        disc_vocab=args.disc_vocab,
        disc_change_rate_hz=args.disc_change_rate,
    )

    t_c, Y = make_continuous_channels(p)
    t_d, X = make_discrete_channels(p)

    cont_path = args.outdir / "continuous_channels.svg"
    disc_path = args.outdir / "discrete_channels.svg"

    # Token ids span [0, disc_vocab - 1]. Offsetting channels by that span plus
    # a 1.25 gap keeps every step curve in its own band; a bare 1.25 offset
    # makes channels with several distinct ids overlap.
    disc_spacing = (p.disc_vocab - 1) + 1.25

    save_continuous_channels_svg(t_c, Y, cont_path, clean=not args.keep_axes)
    save_discrete_channels_svg(
        t_d, X, disc_path, clean=not args.keep_axes, vertical_spacing=disc_spacing
    )

    print("Wrote:")
    print(f" {cont_path}")
    print(f" {disc_path}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Only runs when executed as a script; importing the module draws nothing.
    raise SystemExit(main())
|