Files
internal-docs/arxiv-style/fig-scripts/synth_ics_3d_waterfall.py
2026-02-09 00:24:40 +08:00

241 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
3D "final combined outcome" (time × channel × value) figure of synthetic
continuous and discrete channels, rendered with:
- NO numbers on axes (tick labels removed)
- Axis *titles* kept (texts are okay)
- Reduced whitespace: tight bbox + minimal margins
- White background (non-transparent) suitable for embedding into another SVG
Output:
  PNG by default; SVG (2D-projected vectors) when --format svg is given or
  when --out ends in .svg.
Run:
  uv run python synth_ics_3d_waterfall.py --out ./assets/synth_ics_3d.png
  uv run python synth_ics_3d_waterfall.py --out ./assets/synth_ics_3d.svg --format svg
"""
from __future__ import annotations
import argparse
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
@dataclass
class Params:
    """Settings for synthesizing the channels and rendering the 3D figure.

    Construction validates the values (see ``__post_init__``) so that bad
    command-line input fails immediately with a readable message instead of
    deep inside numpy or matplotlib.
    """

    # RNG seed; continuous channels use ``seed``, discrete ones ``seed + 123``.
    seed: int = 7
    # Signal duration in seconds; the sample count is ``int(seconds * fs)``.
    seconds: float = 10.0
    # Samples per second.
    fs: int = 220
    # Number of continuous channels.
    n_cont: int = 5
    # Number of discrete (step-valued) channels.
    n_disc: int = 2
    # Discrete symbols are drawn from ``range(disc_vocab)``.
    disc_vocab: int = 8
    # Expected number of discrete change points per second.
    disc_change_rate_hz: float = 1.1
    # view
    elev: float = 25.0
    azim: float = -58.0
    # figure size in inches (smaller, more "cube-like")
    fig_w: float = 5.4
    fig_h: float = 5.0
    # discrete rendering: z = disc_z_offset + disc_z_scale * symbol
    disc_z_scale: float = 0.45
    disc_z_offset: float = -1.4
    # margins (figure fraction)
    left: float = 0.03
    right: float = 0.99
    bottom: float = 0.03
    top: float = 0.99

    def __post_init__(self) -> None:
        """Reject settings that cannot produce a figure.

        Raises:
            ValueError: if the settings would yield no samples, no continuous
                channels, an empty symbol vocabulary, a non-positive figure
                size, or an impossible margin layout.
        """
        if self.fs < 1:
            raise ValueError(f"fs must be >= 1 sample/s, got {self.fs}")
        if self.seconds <= 0 or int(self.seconds * self.fs) < 1:
            raise ValueError(
                "seconds * fs must yield at least one sample, "
                f"got seconds={self.seconds}, fs={self.fs}"
            )
        # At least one row is needed because continuous channels are stacked
        # with np.vstack, which rejects an empty list.
        if self.n_cont < 1:
            raise ValueError(f"n_cont must be >= 1, got {self.n_cont}")
        if self.n_disc < 0:
            raise ValueError(f"n_disc must be >= 0, got {self.n_disc}")
        # Symbols are drawn with rng.integers(0, disc_vocab), which needs a
        # non-empty range.
        if self.disc_vocab < 1:
            raise ValueError(f"disc_vocab must be >= 1, got {self.disc_vocab}")
        if self.fig_w <= 0 or self.fig_h <= 0:
            raise ValueError(
                f"figure size must be positive, got {self.fig_w} x {self.fig_h}"
            )
        if not (0.0 <= self.left < self.right <= 1.0):
            raise ValueError(
                f"need 0 <= left < right <= 1, got left={self.left}, right={self.right}"
            )
        if not (0.0 <= self.bottom < self.top <= 1.0):
            raise ValueError(
                f"need 0 <= bottom < top <= 1, got bottom={self.bottom}, top={self.top}"
            )
def _smooth(x: np.ndarray, win: int) -> np.ndarray:
win = max(3, int(win) | 1)
k = np.ones(win, dtype=float)
k /= k.sum()
return np.convolve(x, k, mode="same")
def make_continuous(p: Params) -> tuple[np.ndarray, np.ndarray]:
    """Synthesize ``p.n_cont`` standardized, smoothly varying channels.

    Each channel is built from:
    - a slow sinusoid plus a mid-frequency sinusoid, where the second reuses
      the first's random phase scaled by 0.65;
    - two or three Gaussian bumps at random centres and widths;
    - a little boxcar-smoothed Gaussian noise.

    The sum is then z-scored. Frequencies cycle through two fixed five-entry
    tables, so channels ``i`` and ``i + 5`` share frequencies.

    Returns:
        ``(t, Y)``: ``t`` holds ``int(p.seconds * p.fs)`` evenly spaced times
        on ``[0, p.seconds)`` and ``Y`` has shape ``(p.n_cont, len(t))``.
    """
    rng = np.random.default_rng(p.seed)
    n_samples = int(p.seconds * p.fs)
    t = np.linspace(0, p.seconds, n_samples, endpoint=False)
    slow_hz = (0.08, 0.10, 0.12, 0.09, 0.11)
    mid_hz = (0.55, 0.70, 0.85, 0.62, 0.78)
    noise_win = int(p.fs * 0.05)

    def one_channel(idx: int) -> np.ndarray:
        # Random draws happen in a fixed order (phase, bump count, per-bump
        # centre/width, noise), so a given seed always yields the same figure.
        f_slow = slow_hz[idx % len(slow_hz)]
        f_mid = mid_hz[idx % len(mid_hz)]
        phase = rng.uniform(0, 2 * np.pi)
        sig = (
            0.95 * np.sin(2 * np.pi * f_slow * t + phase)
            + 0.28 * np.sin(2 * np.pi * f_mid * t + 0.65 * phase)
        )
        bumps = np.zeros_like(t)
        for _ in range(rng.integers(2, 4)):
            centre = rng.uniform(0.8, p.seconds - 0.8)
            width = rng.uniform(0.25, 0.80)
            bumps += np.exp(-0.5 * ((t - centre) / (width + 1e-9)) ** 2)
        sig += 0.55 * bumps
        sig += 0.10 * _smooth(rng.normal(0, 1, size=n_samples), win=noise_win)
        return (sig - sig.mean()) / (sig.std() + 1e-9)

    channels = [one_channel(idx) for idx in range(p.n_cont)]
    return t, np.vstack(channels)
def make_discrete(p: Params, t: np.ndarray) -> np.ndarray:
    """Generate ``p.n_disc`` piecewise-constant integer channels aligned with *t*.

    For each channel:
    - ``poisson(max(1, int(p.seconds * p.disc_change_rate_hz))) + 1`` random
      sample indices are drawn (duplicates collapse) as change points.
    - The first segment takes a random symbol from ``range(p.disc_vocab)``.
    - At every later change point, a fresh symbol is drawn with probability
      0.85; it may coincide with the previous one.

    Returns:
        Integer array of shape ``(p.n_disc, len(t))``.
    """
    gen = np.random.default_rng(p.seed + 123)
    n_samples = len(t)
    lam = max(1, int(p.seconds * p.disc_change_rate_hz))
    out = np.zeros((p.n_disc, n_samples), dtype=int)
    # Each `row` is a view into `out`, so slice assignment fills the result.
    for row in out:
        n_draws = gen.poisson(lam) + 1
        cuts = np.unique(gen.integers(0, n_samples, size=n_draws))
        edges = np.sort(np.concatenate([[0], cuts, [n_samples]]))
        symbol = gen.integers(0, p.disc_vocab)
        for start, stop in zip(edges[:-1], edges[1:]):
            # Short-circuit: the first segment never consumes a random draw.
            if start != 0 and gen.random() < 0.85:
                symbol = gen.integers(0, p.disc_vocab)
            row[start:stop] = symbol
    return out
def style_3d_axes(ax):
    """Give a 3D axes opaque white panes with light-grey edges and a thin, faint grid.

    ``Axes3D.grid`` only toggles grid visibility: style kwargs such as
    ``linewidth``/``alpha`` are not applied to 3D grid lines. The grid style is
    therefore written into each axis's (private) ``_axinfo["grid"]`` dict,
    which the 3D axis reads on every draw. Pane and grid styling are
    best-effort: on a matplotlib version lacking these attributes the defaults
    are kept rather than failing.
    """
    import matplotlib.colors as mcolors

    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        try:
            # Default 3D panes are a translucent light grey; force a white fill
            # and a soft edge so the background is uniformly white.
            axis.pane.set_facecolor("white")
            axis.pane.set_edgecolor("0.7")
        except AttributeError:
            pass
        try:
            grid_style = axis._axinfo["grid"]
        except (AttributeError, KeyError):
            continue
        grid_style["linewidth"] = 0.4
        # 3D grid lines ignore an alpha kwarg, so bake alpha into the color.
        grid_style["color"] = mcolors.to_rgba(grid_style["color"], alpha=0.30)
    ax.grid(True)
def remove_tick_numbers_keep_axis_titles(ax):
    """Hide tick labels and tick marks on all three axes, keeping axis titles.

    Tick *locations* are left in place so the grid lines that follow them
    still render. ``set_*ticklabels([])`` installs a null formatter, so no
    numbers are drawn.

    In 3D, tick marks are short line segments whose length comes from the
    axis's internal tick factors, not from ``length``. Setting ``length=0``
    alone therefore leaves them visible, so they are also made fully
    transparent through the public ``color`` tick parameter.
    """
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_zticklabels([])
    invisible = (0.0, 0.0, 0.0, 0.0)
    ax.tick_params(
        axis="both",
        which="both",
        length=0,  # no tick marks (2D semantics)
        pad=0,
        color=invisible,  # hides the 3D tick segments, which ignore `length`
    )
    # Some matplotlib versions do not route axis="both" to the z axis, so
    # apply the same settings to it directly.
    try:
        ax.zaxis.set_tick_params(length=0, pad=0, color=invisible)
    except AttributeError:
        pass
def _build_arg_parser(defaults: Params) -> argparse.ArgumentParser:
    """Build the CLI; numeric defaults come from *defaults* so they cannot drift from Params."""
    ap = argparse.ArgumentParser(description="Render the synthetic ICS 3D waterfall figure.")
    ap.add_argument("--out", type=Path, default=Path("synth_ics_3d.png"))
    ap.add_argument("--format", choices=["png", "svg"], default="png")
    ap.add_argument("--seed", type=int, default=defaults.seed)
    ap.add_argument("--seconds", type=float, default=defaults.seconds)
    ap.add_argument("--fs", type=int, default=defaults.fs)
    ap.add_argument("--n-cont", type=int, default=defaults.n_cont)
    ap.add_argument("--n-disc", type=int, default=defaults.n_disc)
    ap.add_argument("--disc-vocab", type=int, default=defaults.disc_vocab)
    ap.add_argument("--disc-rate", type=float, default=defaults.disc_change_rate_hz)
    ap.add_argument("--elev", type=float, default=defaults.elev)
    ap.add_argument("--azim", type=float, default=defaults.azim)
    ap.add_argument("--fig-w", type=float, default=defaults.fig_w)
    ap.add_argument("--fig-h", type=float, default=defaults.fig_h)
    ap.add_argument("--disc-z-scale", type=float, default=defaults.disc_z_scale)
    ap.add_argument("--disc-z-offset", type=float, default=defaults.disc_z_offset)
    ap.add_argument(
        "--dpi", type=float, default=220, help="figure resolution (affects PNG output)"
    )
    return ap


def _post_step(x: np.ndarray, z: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Expand samples into explicit "steps-post" vertices in data space.

    ``z[i]`` is held from ``x[i]`` up to ``x[i + 1]``, giving the vertices
    ``(x0, z0), (x1, z0), (x1, z1), (x2, z1), ...``. The staircase is built
    here because ``Axes3D`` applies a ``drawstyle`` *after* projection, which
    yields screen-aligned steps rather than steps along the time axis.
    """
    xs = np.repeat(x, 2)[1:]
    zs = np.repeat(z, 2)[:-1]
    return xs, zs


def main(argv: list[str] | None = None) -> None:
    """Parse CLI options, synthesize the channels, render the 3D figure, and save it.

    Args:
        argv: Arguments to parse; ``None`` (the default) means ``sys.argv[1:]``.

    The output is SVG when ``--format svg`` is given or ``--out`` ends in
    ``.svg``; otherwise it is PNG.
    """
    args = _build_arg_parser(Params()).parse_args(argv)
    p = Params(
        seed=args.seed,
        seconds=args.seconds,
        fs=args.fs,
        n_cont=args.n_cont,
        n_disc=args.n_disc,
        disc_vocab=args.disc_vocab,
        disc_change_rate_hz=args.disc_rate,
        elev=args.elev,
        azim=args.azim,
        fig_w=args.fig_w,
        fig_h=args.fig_h,
        disc_z_scale=args.disc_z_scale,
        disc_z_offset=args.disc_z_offset,
    )
    t, Yc = make_continuous(p)
    Xd = make_discrete(p, t)

    fig = plt.figure(figsize=(p.fig_w, p.fig_h), dpi=args.dpi, facecolor="white")
    try:
        ax = fig.add_subplot(111, projection="3d")
        style_3d_axes(ax)
        # Reduce whitespace around axes (tight placement)
        fig.subplots_adjust(left=p.left, right=p.right, bottom=p.bottom, top=p.top)

        # Continuous channels: one curve per row of Yc at y = channel index.
        for i, z in enumerate(Yc):
            y = np.full_like(t, fill_value=i, dtype=float)
            ax.plot(t, y, z, linewidth=2.0)

        # Discrete channels: staircases placed after the continuous ones.
        for j, symbols in enumerate(Xd):
            ch = p.n_cont + j
            z = p.disc_z_offset + p.disc_z_scale * symbols.astype(float)
            xs, zs = _post_step(t, z)
            ax.plot(xs, np.full_like(xs, fill_value=ch, dtype=float), zs, linewidth=2.2)

        # Axis titles kept
        ax.set_xlabel("time")
        ax.set_ylabel("channel")
        ax.set_zlabel("value")
        # Remove numeric tick labels + tick marks
        remove_tick_numbers_keep_axis_titles(ax)
        # Camera
        ax.view_init(elev=p.elev, azim=p.azim)

        # Save tightly (minimize white border)
        args.out.parent.mkdir(parents=True, exist_ok=True)
        fmt = "svg" if args.format == "svg" or args.out.suffix.lower() == ".svg" else "png"
        fig.savefig(
            args.out, format=fmt, bbox_inches="tight", pad_inches=0.03, facecolor="white"
        )
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)
    print(f"Wrote: {args.out}")


if __name__ == "__main__":
    main()