#!/usr/bin/env python3
"""
Draw *separate* SVG figures for:
  1) Continuous channels (multiple smooth curves per figure)
  2) Discrete channels (multiple step-like/token curves per figure)

Outputs (default):
  out/continuous_channels.svg
  out/discrete_channels.svg

Notes:
- Transparent background (good for draw.io / LaTeX / diagrams).
- No axes/frames by default (diagram-friendly).
- Curves are synthetic placeholders; replace `make_*_channels()` with your
  real data.
"""

from __future__ import annotations

import argparse
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt


# ----------------------------
# Data generators (placeholders)
# ----------------------------


@dataclass
class GenParams:
    """Parameters shared by the synthetic channel generators."""

    seconds: float = 10.0  # signal duration in seconds
    fs: int = 200  # samples per second; T = int(seconds * fs)
    seed: int = 7  # base RNG seed (discrete channels use seed + 100)
    n_cont: int = 6  # number of continuous channels (curves)
    n_disc: int = 5  # number of discrete channels (curves)
    disc_vocab: int = 8  # token/vocab size for discrete channels
    disc_change_rate_hz: float = 1.2  # how often discrete tokens change


def _time_axis(p: GenParams) -> np.ndarray:
    """Return T = int(seconds * fs) evenly spaced sample times in [0, seconds)."""
    n_samples = int(p.seconds * p.fs)
    return np.linspace(0, p.seconds, n_samples, endpoint=False)


def make_continuous_channels(p: GenParams) -> tuple[np.ndarray, np.ndarray]:
    """
    Generate smooth synthetic continuous channels.

    Each channel is a slow plus a faster sinusoid (random phase) with white
    noise smoothed by a 9-tap moving average. Every channel is z-scored and
    then scaled by 0.8 and shifted by 0.15 * i, so the curves are staggered
    slightly (they still overlap).

    Returns:
        t: shape (T,), sample times in seconds.
        Y: shape (n_cont, T); shape (0, T) when n_cont == 0.
    """
    rng = np.random.default_rng(p.seed)
    t = _time_axis(p)
    n_samples = t.size
    kernel = np.ones(9) / 9.0  # moving-average kernel, identical for every channel

    Y = []
    for i in range(p.n_cont):
        # Multi-scale smooth-ish signal (t is in seconds, so f1/f2 are in Hz).
        f1 = 0.15 + 0.06 * i
        f2 = 0.8 + 0.15 * (i % 3)
        phase = rng.uniform(0, 2 * np.pi)
        y = (
            0.9 * np.sin(2 * np.pi * f1 * t + phase)
            + 0.35 * np.sin(2 * np.pi * f2 * t + 1.3 * phase)
        )

        # Mild colored-ish noise: smoothed white noise.
        w = np.convolve(rng.normal(0, 1, size=n_samples), kernel, mode="same")
        y = y + 0.15 * w

        # Normalize each channel for a consistent visual scale, then offset.
        y = (y - np.mean(y)) / (np.std(y) + 1e-9)
        Y.append(0.8 * y + 0.15 * i)

    if not Y:
        return t, np.empty((0, n_samples))
    return t, np.vstack(Y)


def make_discrete_channels(p: GenParams) -> tuple[np.ndarray, np.ndarray]:
    """
    Generate discrete channels as piecewise-constant token IDs (integers).

    Each channel gets about Poisson(seconds * disc_change_rate_hz) + 1 random
    change points. At every change point after the start, the token is
    redrawn with probability 0.85.

    Returns:
        t: shape (T,), sample times in seconds.
        X: shape (n_disc, T), integers in [0, disc_vocab-1].

    Raises:
        ValueError: if disc_vocab < 1.
    """
    if p.disc_vocab < 1:
        raise ValueError(f"disc_vocab must be >= 1, got {p.disc_vocab}")

    rng = np.random.default_rng(p.seed + 100)
    t = _time_axis(p)
    n_samples = t.size
    X = np.zeros((p.n_disc, n_samples), dtype=int)
    if n_samples == 0:
        return t, X

    # Expected number of changes per channel (at least one).
    expected_changes = int(max(1, p.seconds * p.disc_change_rate_hz))

    for c in range(p.n_disc):
        k = rng.poisson(expected_changes) + 1
        # Add both ends before deduplicating. A randomly drawn 0 then merges
        # with the start instead of creating an empty leading segment.
        drawn = rng.integers(0, n_samples, size=k)
        change_pts = np.unique(np.concatenate(([0], drawn, [n_samples])))

        cur = rng.integers(0, p.disc_vocab)
        for a, b in zip(change_pts[:-1], change_pts[1:]):
            # Occasional token jump at every change point except the start.
            if a != 0 and rng.random() < 0.85:
                cur = rng.integers(0, p.disc_vocab)
            X[c, a:b] = cur

    return t, X


# ----------------------------
# Plotting helpers
# ----------------------------


def _make_transparent_figure(
    width_in: float, height_in: float
) -> tuple[plt.Figure, plt.Axes]:
    """Create a figure and a near-full-bleed axes, both with transparent patches."""
    fig = plt.figure(figsize=(width_in, height_in), dpi=200)
    fig.patch.set_alpha(0.0)
    ax = fig.add_axes([0.03, 0.03, 0.94, 0.94])
    ax.patch.set_alpha(0.0)
    return fig, ax


def _style_axes(ax: plt.Axes, *, clean: bool, ylabel: str) -> None:
    """Hide the axes entirely when clean; otherwise label x as 't' and y as *ylabel*."""
    if clean:
        ax.set_axis_off()
    else:
        ax.set_xlabel("t")
        ax.set_ylabel(ylabel)


def _fit_limits(
    ax: plt.Axes, t: np.ndarray, y_all: np.ndarray, pad_frac: float
) -> None:
    """Fit x to [t[0], t[-1]] and y to the data range padded by pad_frac of its span."""
    ymin, ymax = float(np.min(y_all)), float(np.max(y_all))
    ypad = pad_frac * (ymax - ymin + 1e-9)  # epsilon avoids zero height for flat data
    ax.set_xlim(t[0], t[-1])
    ax.set_ylim(ymin - ypad, ymax + ypad)


def _save_svg(fig: plt.Figure, out_path: Path) -> None:
    """Write *fig* as a tight, transparent SVG, creating parent directories."""
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(
        out_path,
        format="svg",
        bbox_inches="tight",
        pad_inches=0.0,
        transparent=True,
    )


def save_continuous_channels_svg(
    t: np.ndarray,
    Y: np.ndarray,
    out_path: Path,
    *,
    lw: float = 2.0,
    clean: bool = True,
) -> None:
    """
    Plot multiple continuous curves in one figure and save it as SVG.

    Y shape: (n_cont, T). Colors come from matplotlib's default color cycle.
    The figure is always closed, even if plotting or saving raises.
    """
    fig, ax = _make_transparent_figure(width_in=6.0, height_in=2.2)
    try:
        for y in Y:
            ax.plot(t, y, linewidth=lw)
        _style_axes(ax, clean=clean, ylabel="value")
        _fit_limits(ax, t, Y, pad_frac=0.08)
        _save_svg(fig, out_path)
    finally:
        plt.close(fig)


def save_discrete_channels_svg(
    t: np.ndarray,
    X: np.ndarray,
    out_path: Path,
    *,
    lw: float = 2.0,
    clean: bool = True,
    vertical_spacing: float = 1.25,
) -> None:
    """
    Plot multiple discrete (piecewise-constant) curves in one figure as SVG.

    X shape: (n_disc, T), integers. Each channel is drawn as a post-step plot
    shifted up by i * vertical_spacing, in token-id units. The curves are
    only fully separated when vertical_spacing is at least the token range
    (e.g. disc_vocab). With the default of 1.25 they are merely staggered.
    The figure is always closed, even if plotting or saving raises.
    """
    fig, ax = _make_transparent_figure(width_in=6.0, height_in=2.2)
    try:
        offsets = np.arange(X.shape[0])[:, None] * vertical_spacing
        shifted = X.astype(float) + offsets
        for y in shifted:
            ax.step(t, y, where="post", linewidth=lw)
        _style_axes(ax, clean=clean, ylabel="token id (offset)")
        _fit_limits(ax, t, shifted, pad_frac=0.10)
        _save_svg(fig, out_path)
    finally:
        plt.close(fig)


# ----------------------------
# CLI
# ----------------------------


def main() -> None:
    """Parse CLI options, generate both channel sets, and write the two SVGs."""
    ap = argparse.ArgumentParser(
        description="Draw synthetic continuous and discrete channel figures as SVG."
    )
    ap.add_argument("--outdir", type=Path, default=Path("out"),
                    help="Output directory")
    ap.add_argument("--seed", type=int, default=7, help="Base RNG seed")
    ap.add_argument("--seconds", type=float, default=10.0,
                    help="Signal duration in seconds")
    ap.add_argument("--fs", type=int, default=200,
                    help="Samples per second")
    ap.add_argument("--n-cont", type=int, default=6,
                    help="Number of continuous channels")
    ap.add_argument("--n-disc", type=int, default=5,
                    help="Number of discrete channels")
    ap.add_argument("--disc-vocab", type=int, default=8,
                    help="Token vocabulary size for discrete channels")
    ap.add_argument("--disc-change-rate", type=float, default=1.2,
                    help="Approximate token changes per second")
    ap.add_argument("--keep-axes", action="store_true", help="Show axes/labels (default: off)")
    args = ap.parse_args()

    # Reject inputs that would otherwise crash deep inside numpy/matplotlib.
    if args.seconds <= 0 or args.fs <= 0:
        ap.error("--seconds and --fs must be positive")
    if int(args.seconds * args.fs) < 1:
        ap.error("--seconds * --fs must yield at least one sample")
    for flag, value in (
        ("--n-cont", args.n_cont),
        ("--n-disc", args.n_disc),
        ("--disc-vocab", args.disc_vocab),
    ):
        if value < 1:
            ap.error(f"{flag} must be >= 1")

    p = GenParams(
        seconds=args.seconds,
        fs=args.fs,
        seed=args.seed,
        n_cont=args.n_cont,
        n_disc=args.n_disc,
        disc_vocab=args.disc_vocab,
        disc_change_rate_hz=args.disc_change_rate,
    )

    t_c, Y = make_continuous_channels(p)
    t_d, X = make_discrete_channels(p)

    cont_path = args.outdir / "continuous_channels.svg"
    disc_path = args.outdir / "discrete_channels.svg"

    save_continuous_channels_svg(t_c, Y, cont_path, clean=not args.keep_axes)
    save_discrete_channels_svg(t_d, X, disc_path, clean=not args.keep_axes)

    print("Wrote:")
    print(f"  {cont_path}")
    print(f"  {disc_path}")


if __name__ == "__main__":
    main()