forked from manbo/internal-docs
Add: python scripts for figure generation
This commit is contained in:
188
arxiv-style/fig-scripts/make_ddpm_like_svg.py
Normal file
188
arxiv-style/fig-scripts/make_ddpm_like_svg.py
Normal file
@@ -0,0 +1,188 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
DDPM-like residual curve SVGs (separate files, fixed colors):
|
||||
- noisy_residual.svg (blue)
|
||||
- denoised_residual.svg (purple)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
@dataclass
|
||||
class DDPMStyleParams:
|
||||
seconds: float = 12.0
|
||||
fs: int = 250
|
||||
seed: int = 7
|
||||
|
||||
baseline_amp: float = 0.10
|
||||
mid_wiggle_amp: float = 0.18
|
||||
colored_noise_amp: float = 0.65
|
||||
colored_alpha: float = 1.0
|
||||
|
||||
burst_rate_hz: float = 0.30
|
||||
burst_amp: float = 0.9
|
||||
burst_width_ms: float = 55
|
||||
|
||||
denoise_sigmas_ms: tuple[float, ...] = (25, 60, 140)
|
||||
denoise_weights: tuple[float, ...] = (0.25, 0.35, 0.40)
|
||||
denoise_texture_keep: float = 0.10
|
||||
|
||||
|
||||
def gaussian_smooth(x: np.ndarray, sigma_samples: float) -> np.ndarray:
|
||||
if sigma_samples <= 0:
|
||||
return x.copy()
|
||||
radius = int(np.ceil(4 * sigma_samples))
|
||||
k = np.arange(-radius, radius + 1, dtype=float)
|
||||
kernel = np.exp(-(k**2) / (2 * sigma_samples**2))
|
||||
kernel /= kernel.sum()
|
||||
return np.convolve(x, kernel, mode="same")
|
||||
|
||||
|
||||
def colored_noise_1_f(n: int, rng: np.random.Generator, alpha: float) -> np.ndarray:
|
||||
white = rng.normal(0, 1, size=n)
|
||||
spec = np.fft.rfft(white)
|
||||
|
||||
freqs = np.fft.rfftfreq(n, d=1.0)
|
||||
scale = np.ones_like(freqs)
|
||||
nonzero = freqs > 0
|
||||
scale[nonzero] = 1.0 / (freqs[nonzero] ** (alpha / 2.0))
|
||||
|
||||
spec *= scale
|
||||
x = np.fft.irfft(spec, n=n)
|
||||
|
||||
x = x - np.mean(x)
|
||||
x = x / (np.std(x) + 1e-9)
|
||||
return x
|
||||
|
||||
|
||||
def make_ddpm_like_residual(p: DDPMStyleParams) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
rng = np.random.default_rng(p.seed)
|
||||
n = int(p.seconds * p.fs)
|
||||
t = np.linspace(0, p.seconds, n, endpoint=False)
|
||||
|
||||
baseline = (
|
||||
0.8 * np.sin(2 * np.pi * 0.18 * t + 0.4)
|
||||
+ 0.35 * np.sin(2 * np.pi * 0.06 * t + 2.2)
|
||||
) * p.baseline_amp
|
||||
|
||||
mid = (
|
||||
0.9 * np.sin(2 * np.pi * 0.9 * t + 1.1)
|
||||
+ 0.5 * np.sin(2 * np.pi * 1.6 * t + 0.2)
|
||||
+ 0.3 * np.sin(2 * np.pi * 2.4 * t + 2.6)
|
||||
) * p.mid_wiggle_amp
|
||||
|
||||
col = colored_noise_1_f(n, rng, alpha=p.colored_alpha) * p.colored_noise_amp
|
||||
|
||||
expected = p.burst_rate_hz * p.seconds
|
||||
k = rng.poisson(expected)
|
||||
impulses = np.zeros(n)
|
||||
if k > 0:
|
||||
idx = rng.integers(0, n, size=k)
|
||||
impulses[idx] = rng.normal(loc=1.0, scale=0.35, size=k)
|
||||
|
||||
width = max(int(p.fs * (p.burst_width_ms / 1000.0)), 7)
|
||||
u = np.arange(width)
|
||||
kernel = np.exp(-u / (p.fs * 0.012)) * np.hanning(width)
|
||||
kernel /= (kernel.max() + 1e-9)
|
||||
bursts = np.convolve(impulses, kernel, mode="same") * p.burst_amp
|
||||
|
||||
noisy = baseline + mid + col + bursts
|
||||
|
||||
sigmas_samples = [(ms / 1000.0) * p.fs / 3.0 for ms in p.denoise_sigmas_ms]
|
||||
smooths = [gaussian_smooth(noisy, s) for s in sigmas_samples]
|
||||
|
||||
den_base = np.zeros_like(noisy)
|
||||
for w, sm in zip(p.denoise_weights, smooths):
|
||||
den_base += w * sm
|
||||
|
||||
hf = noisy - gaussian_smooth(noisy, sigma_samples=p.fs * 0.03)
|
||||
denoised = den_base + p.denoise_texture_keep * (hf / (np.std(hf) + 1e-9)) * (0.10 * np.std(den_base))
|
||||
|
||||
return t, noisy, denoised
|
||||
|
||||
|
||||
def save_single_curve_svg(
|
||||
t: np.ndarray,
|
||||
y: np.ndarray,
|
||||
out_path: Path,
|
||||
*,
|
||||
color: str,
|
||||
lw: float = 2.2,
|
||||
) -> None:
|
||||
fig = plt.figure(figsize=(5.4, 1.6), dpi=200)
|
||||
|
||||
# Make figure background transparent
|
||||
fig.patch.set_alpha(0.0)
|
||||
|
||||
ax = fig.add_axes([0.03, 0.03, 0.94, 0.94])
|
||||
|
||||
# Make axes background transparent
|
||||
ax.patch.set_alpha(0.0)
|
||||
|
||||
ax.plot(t, y, linewidth=lw, color=color)
|
||||
|
||||
# clean, diagram-friendly
|
||||
ax.set_axis_off()
|
||||
ymin, ymax = np.min(y), np.max(y)
|
||||
ypad = 0.08 * (ymax - ymin + 1e-9)
|
||||
ax.set_xlim(t[0], t[-1])
|
||||
ax.set_ylim(ymin - ypad, ymax + ypad)
|
||||
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fig.savefig(
|
||||
out_path,
|
||||
format="svg",
|
||||
bbox_inches="tight",
|
||||
pad_inches=0.0,
|
||||
transparent=True, # <-- key for transparent output
|
||||
)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--outdir", type=Path, default=Path("."))
|
||||
ap.add_argument("--seed", type=int, default=7)
|
||||
ap.add_argument("--seconds", type=float, default=12.0)
|
||||
ap.add_argument("--fs", type=int, default=250)
|
||||
|
||||
ap.add_argument("--alpha", type=float, default=1.0)
|
||||
ap.add_argument("--noise-amp", type=float, default=0.65)
|
||||
ap.add_argument("--texture-keep", type=float, default=0.10)
|
||||
|
||||
ap.add_argument("--prefix", type=str, default="")
|
||||
args = ap.parse_args()
|
||||
|
||||
p = DDPMStyleParams(
|
||||
seconds=args.seconds,
|
||||
fs=args.fs,
|
||||
seed=args.seed,
|
||||
colored_alpha=args.alpha,
|
||||
colored_noise_amp=args.noise_amp,
|
||||
denoise_texture_keep=args.texture_keep,
|
||||
)
|
||||
|
||||
t, noisy, den = make_ddpm_like_residual(p)
|
||||
|
||||
outdir = args.outdir
|
||||
noisy_path = outdir / f"{args.prefix}noisy_residual.svg"
|
||||
den_path = outdir / f"{args.prefix}denoised_residual.svg"
|
||||
|
||||
# Fixed colors as you requested
|
||||
save_single_curve_svg(t, noisy, noisy_path, color="blue")
|
||||
save_single_curve_svg(t, den, den_path, color="purple")
|
||||
|
||||
print("Wrote:")
|
||||
print(f" {noisy_path}")
|
||||
print(f" {den_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user