forked from manbo/internal-docs
Add: python scripts for figure generation
This commit is contained in:
318
arxiv-style/fig-scripts/draw_synthetic_ics_optionB.py
Normal file
318
arxiv-style/fig-scripts/draw_synthetic_ics_optionB.py
Normal file
@@ -0,0 +1,318 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Option B: "Synthetic ICS Data" as a mini process-story strip (high-level features)
|
||||
- ONE SVG, transparent background
|
||||
- Two frames by default: "steady/normal" -> "disturbance/recovery"
|
||||
- Each frame contains:
|
||||
- Top: multiple continuous feature curves
|
||||
- Bottom: discrete/categorical strip (colored blocks)
|
||||
- A vertical dashed alignment line crossing both
|
||||
- Optional shaded regime window
|
||||
- A right-pointing arrow between frames
|
||||
|
||||
No text, no numbers (axes lines only). Good for draw.io embedding.
|
||||
|
||||
Run:
|
||||
uv run python draw_synthetic_ics_optionB.py --out ./assets/synth_ics_optionB.svg
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.patches import Rectangle, FancyArrowPatch
|
||||
|
||||
|
||||
@dataclass
|
||||
class Params:
|
||||
seed: int = 7
|
||||
seconds: float = 8.0
|
||||
fs: int = 250
|
||||
|
||||
# Two-frame story
|
||||
n_frames: int = 2
|
||||
|
||||
# Per-frame visuals
|
||||
n_curves: int = 3
|
||||
n_bins: int = 32
|
||||
disc_vocab: int = 8
|
||||
|
||||
# Layout
|
||||
width_in: float = 8.2
|
||||
height_in: float = 2.4
|
||||
# Relative layout inside the figure
|
||||
margin_left: float = 0.05
|
||||
margin_right: float = 0.05
|
||||
margin_bottom: float = 0.12
|
||||
margin_top: float = 0.10
|
||||
frame_gap: float = 0.08 # gap (figure fraction) between frames (space for arrow)
|
||||
|
||||
# Styling
|
||||
curve_lw: float = 2.1
|
||||
ghost_lw: float = 1.8
|
||||
strip_height: float = 0.65
|
||||
strip_gap_frac: float = 0.12
|
||||
|
||||
# Cues
|
||||
show_alignment_line: bool = True
|
||||
align_x_frac: float = 0.60
|
||||
show_regime_window: bool = True
|
||||
regime_start_frac: float = 0.30
|
||||
regime_end_frac: float = 0.46
|
||||
show_real_ghost: bool = False # keep default off for cleaner story
|
||||
show_axes_spines: bool = True # axes lines only (no ticks/labels)
|
||||
|
||||
|
||||
# ---------- helpers ----------
|
||||
|
||||
def _smooth(x: np.ndarray, win: int) -> np.ndarray:
|
||||
win = max(3, int(win) | 1)
|
||||
k = np.ones(win, dtype=float)
|
||||
k /= k.sum()
|
||||
return np.convolve(x, k, mode="same")
|
||||
|
||||
|
||||
def _axes_only(ax: plt.Axes, *, keep_spines: bool) -> None:
|
||||
ax.set_xlabel("")
|
||||
ax.set_ylabel("")
|
||||
ax.set_title("")
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
ax.tick_params(
|
||||
axis="both",
|
||||
which="both",
|
||||
bottom=False,
|
||||
left=False,
|
||||
top=False,
|
||||
right=False,
|
||||
labelbottom=False,
|
||||
labelleft=False,
|
||||
)
|
||||
ax.grid(False)
|
||||
if keep_spines:
|
||||
for s in ("left", "right", "top", "bottom"):
|
||||
ax.spines[s].set_visible(True)
|
||||
else:
|
||||
for s in ("left", "right", "top", "bottom"):
|
||||
ax.spines[s].set_visible(False)
|
||||
|
||||
|
||||
def make_frame_continuous(seed: int, seconds: float, fs: int, n_curves: int, style: str) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
style:
|
||||
- "steady": smoother, smaller bumps
|
||||
- "disturb": larger bumps and more variance
|
||||
"""
|
||||
rng = np.random.default_rng(seed)
|
||||
T = int(seconds * fs)
|
||||
t = np.linspace(0, seconds, T, endpoint=False)
|
||||
|
||||
amp_bump = 0.40 if style == "steady" else 0.85
|
||||
amp_noise = 0.09 if style == "steady" else 0.14
|
||||
amp_scale = 0.38 if style == "steady" else 0.46
|
||||
|
||||
base_freqs = [0.10, 0.08, 0.12, 0.09]
|
||||
mid_freqs = [0.65, 0.78, 0.90, 0.72]
|
||||
|
||||
Y = []
|
||||
for i in range(n_curves):
|
||||
f_slow = base_freqs[i % len(base_freqs)]
|
||||
f_mid = mid_freqs[i % len(mid_freqs)]
|
||||
ph = rng.uniform(0, 2 * np.pi)
|
||||
|
||||
y = (
|
||||
0.95 * np.sin(2 * np.pi * f_slow * t + ph)
|
||||
+ 0.28 * np.sin(2 * np.pi * f_mid * t + 0.65 * ph)
|
||||
)
|
||||
|
||||
bumps = np.zeros_like(t)
|
||||
n_bumps = 2 if style == "steady" else 3
|
||||
for _ in range(n_bumps):
|
||||
mu = rng.uniform(0.9, seconds - 0.9)
|
||||
sig = rng.uniform(0.35, 0.75) if style == "steady" else rng.uniform(0.20, 0.55)
|
||||
bumps += np.exp(-0.5 * ((t - mu) / (sig + 1e-9)) ** 2)
|
||||
y += amp_bump * bumps
|
||||
|
||||
noise = _smooth(rng.normal(0, 1, size=T), win=int(fs * 0.04))
|
||||
y += amp_noise * noise
|
||||
|
||||
y = (y - y.mean()) / (y.std() + 1e-9)
|
||||
y *= amp_scale
|
||||
Y.append(y)
|
||||
|
||||
return t, np.vstack(Y)
|
||||
|
||||
|
||||
def make_frame_discrete(seed: int, n_bins: int, vocab: int, style: str) -> np.ndarray:
|
||||
"""
|
||||
style:
|
||||
- "steady": fewer transitions
|
||||
- "disturb": more transitions
|
||||
"""
|
||||
rng = np.random.default_rng(seed + 111)
|
||||
ids = np.zeros(n_bins, dtype=int)
|
||||
|
||||
p_change = 0.20 if style == "steady" else 0.38
|
||||
cur = rng.integers(0, vocab)
|
||||
for i in range(n_bins):
|
||||
if i == 0 or rng.random() < p_change:
|
||||
cur = rng.integers(0, vocab)
|
||||
ids[i] = cur
|
||||
return ids
|
||||
|
||||
|
||||
def draw_frame(ax_top: plt.Axes, ax_bot: plt.Axes, t: np.ndarray, Y: np.ndarray, ids: np.ndarray, p: Params) -> None:
|
||||
# Optional cues
|
||||
x0, x1 = float(t[0]), float(t[-1])
|
||||
span = x1 - x0
|
||||
|
||||
if p.show_regime_window:
|
||||
rs = x0 + p.regime_start_frac * span
|
||||
re = x0 + p.regime_end_frac * span
|
||||
ax_top.axvspan(rs, re, alpha=0.12) # default color
|
||||
ax_bot.axvspan(rs, re, alpha=0.12)
|
||||
|
||||
if p.show_alignment_line:
|
||||
vx = x0 + p.align_x_frac * span
|
||||
ax_top.axvline(vx, linestyle="--", linewidth=1.15, alpha=0.7)
|
||||
ax_bot.axvline(vx, linestyle="--", linewidth=1.15, alpha=0.7)
|
||||
|
||||
# Curves
|
||||
curve_colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#9467bd"]
|
||||
for i in range(Y.shape[0]):
|
||||
ax_top.plot(t, Y[i], linewidth=p.curve_lw, color=curve_colors[i % len(curve_colors)])
|
||||
|
||||
ymin, ymax = float(Y.min()), float(Y.max())
|
||||
ypad = 0.10 * (ymax - ymin + 1e-9)
|
||||
ax_top.set_xlim(x0, x1)
|
||||
ax_top.set_ylim(ymin - ypad, ymax + ypad)
|
||||
|
||||
# Discrete strip
|
||||
palette = [
|
||||
"#e41a1c", "#377eb8", "#4daf4a", "#984ea3",
|
||||
"#ff7f00", "#ffff33", "#a65628", "#f781bf",
|
||||
]
|
||||
|
||||
ax_bot.set_xlim(x0, x1)
|
||||
ax_bot.set_ylim(0, 1)
|
||||
|
||||
n = len(ids)
|
||||
bin_w = span / n
|
||||
gap = p.strip_gap_frac * bin_w
|
||||
y = (1 - p.strip_height) / 2
|
||||
|
||||
for i, cat in enumerate(ids):
|
||||
left = x0 + i * bin_w + gap / 2
|
||||
width = bin_w - gap
|
||||
ax_bot.add_patch(
|
||||
Rectangle((left, y), width, p.strip_height, facecolor=palette[int(cat) % len(palette)], edgecolor="none")
|
||||
)
|
||||
|
||||
# Axes-only style
|
||||
_axes_only(ax_top, keep_spines=p.show_axes_spines)
|
||||
_axes_only(ax_bot, keep_spines=p.show_axes_spines)
|
||||
|
||||
|
||||
# ---------- main drawing ----------
|
||||
|
||||
def draw_optionB(out_path: Path, p: Params) -> None:
|
||||
fig = plt.figure(figsize=(p.width_in, p.height_in), dpi=200)
|
||||
fig.patch.set_alpha(0.0)
|
||||
|
||||
# Compute frame layout in figure coordinates
|
||||
# Each frame has two stacked axes: top curves and bottom strip.
|
||||
usable_w = 1.0 - p.margin_left - p.margin_right
|
||||
usable_h = 1.0 - p.margin_bottom - p.margin_top
|
||||
|
||||
# Leave gap between frames for arrow
|
||||
total_gap = p.frame_gap * (p.n_frames - 1)
|
||||
frame_w = (usable_w - total_gap) / p.n_frames
|
||||
|
||||
# Within each frame: vertical split
|
||||
top_h = usable_h * 0.70
|
||||
bot_h = usable_h * 0.18
|
||||
v_gap = usable_h * 0.06
|
||||
# bottoms
|
||||
bot_y = p.margin_bottom
|
||||
top_y = bot_y + bot_h + v_gap
|
||||
|
||||
axes_pairs = []
|
||||
for f in range(p.n_frames):
|
||||
left = p.margin_left + f * (frame_w + p.frame_gap)
|
||||
ax_top = fig.add_axes([left, top_y, frame_w, top_h])
|
||||
ax_bot = fig.add_axes([left, bot_y, frame_w, bot_h], sharex=ax_top)
|
||||
ax_top.patch.set_alpha(0.0)
|
||||
ax_bot.patch.set_alpha(0.0)
|
||||
axes_pairs.append((ax_top, ax_bot))
|
||||
|
||||
# Data per frame
|
||||
styles = ["steady", "disturb"] if p.n_frames == 2 else ["steady"] * (p.n_frames - 1) + ["disturb"]
|
||||
for idx, ((ax_top, ax_bot), style) in enumerate(zip(axes_pairs, styles)):
|
||||
t, Y = make_frame_continuous(p.seed + 10 * idx, p.seconds, p.fs, p.n_curves, style=style)
|
||||
ids = make_frame_discrete(p.seed + 10 * idx, p.n_bins, p.disc_vocab, style=style)
|
||||
draw_frame(ax_top, ax_bot, t, Y, ids, p)
|
||||
|
||||
# Add a visual arrow between frames (in figure coordinates)
|
||||
if p.n_frames >= 2:
|
||||
for f in range(p.n_frames - 1):
|
||||
# center between frame f and f+1
|
||||
x_left = p.margin_left + f * (frame_w + p.frame_gap) + frame_w
|
||||
x_right = p.margin_left + (f + 1) * (frame_w + p.frame_gap)
|
||||
x_mid = (x_left + x_right) / 2
|
||||
# arrow y in the middle of the frame stack
|
||||
y_mid = bot_y + (bot_h + v_gap + top_h) / 2
|
||||
|
||||
arr = FancyArrowPatch(
|
||||
(x_mid - 0.015, y_mid),
|
||||
(x_mid + 0.015, y_mid),
|
||||
transform=fig.transFigure,
|
||||
arrowstyle="-|>",
|
||||
mutation_scale=18,
|
||||
linewidth=1.6,
|
||||
color="black",
|
||||
)
|
||||
fig.patches.append(arr)
|
||||
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fig.savefig(out_path, format="svg", transparent=True, bbox_inches="tight", pad_inches=0.0)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--out", type=Path, default=Path("synth_ics_optionB.svg"))
|
||||
ap.add_argument("--seed", type=int, default=7)
|
||||
ap.add_argument("--seconds", type=float, default=8.0)
|
||||
ap.add_argument("--fs", type=int, default=250)
|
||||
ap.add_argument("--frames", type=int, default=2, choices=[2, 3], help="2 or 3 frames (story strip)")
|
||||
ap.add_argument("--curves", type=int, default=3)
|
||||
ap.add_argument("--bins", type=int, default=32)
|
||||
ap.add_argument("--vocab", type=int, default=8)
|
||||
ap.add_argument("--no-align", action="store_true")
|
||||
ap.add_argument("--no-regime", action="store_true")
|
||||
ap.add_argument("--no-spines", action="store_true")
|
||||
args = ap.parse_args()
|
||||
|
||||
p = Params(
|
||||
seed=args.seed,
|
||||
seconds=args.seconds,
|
||||
fs=args.fs,
|
||||
n_frames=args.frames,
|
||||
n_curves=args.curves,
|
||||
n_bins=args.bins,
|
||||
disc_vocab=args.vocab,
|
||||
show_alignment_line=not args.no_align,
|
||||
show_regime_window=not args.no_regime,
|
||||
show_axes_spines=not args.no_spines,
|
||||
)
|
||||
|
||||
draw_optionB(args.out, p)
|
||||
print(f"Wrote: {args.out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user