From 9e1e7338a26f293c2f554e2687a21c9b63fc8bb4 Mon Sep 17 00:00:00 2001
From: MingzheYang
Date: Tue, 27 Jan 2026 19:27:00 +0800
Subject: [PATCH] Fix quantile transform scaling and document the decision

---
 docs/decisions.md         | 7 +++++++
 example/data_utils.py     | 2 ++
 example/export_samples.py | 7 ++++---
 report.md                 | 4 ++--
 4 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/docs/decisions.md b/docs/decisions.md
index 564243e..ca8858f 100644
--- a/docs/decisions.md
+++ b/docs/decisions.md
@@ -48,3 +48,10 @@
 - `example/prepare_data.py`
 - `example/export_samples.py`
 - `example/config.json`
+
+## 2026-01-27 — Quantile transform without extra standardization
+- **Decision**: When the quantile transform is enabled, skip mean/std normalization (quantile output is already ~N(0,1)).
+- **Why**: Prevent the scale mismatch that pushed values to the max bounds and inflated the KS statistic.
+- **Files**:
+  - `example/data_utils.py`
+  - `example/export_samples.py`

diff --git a/example/data_utils.py b/example/data_utils.py
index 3ca3cb3..29c730a 100755
--- a/example/data_utils.py
+++ b/example/data_utils.py
@@ -288,6 +288,8 @@ def normalize_cont(
         if not quantile_probs or not quantile_values:
             raise ValueError("use_quantile_transform enabled but quantile stats missing")
         x = apply_quantile_transform(x, cont_cols, quantile_probs, quantile_values)
+        # quantile transform already targets N(0,1); skip extra standardization
+        return x
     mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
     std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
     return (x - mean_t) / std_t

diff --git a/example/export_samples.py b/example/export_samples.py
index f9ca0a1..680dec2 100644
--- a/example/export_samples.py
+++ b/example/export_samples.py
@@ -271,11 +271,12 @@ def main():
 
     if args.clip_k > 0:
         x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
-    mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
-    std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
-    x_cont = x_cont * std_vec + mean_vec
     if use_quantile:
         x_cont = inverse_quantile_transform(x_cont, cont_cols, quantile_probs, quantile_values)
+    else:
+        mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
+        std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
+        x_cont = x_cont * std_vec + mean_vec
     for i, c in enumerate(cont_cols):
         if transforms.get(c) == "log1p":
             x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])

diff --git a/report.md b/report.md
index 356cd3c..8db76b9 100644
--- a/report.md
+++ b/report.md
@@ -144,7 +144,7 @@ Defined in `example/data_utils.py` + `example/prepare_data.py`. Key steps:
 
 - Streaming mean/std/min/max + int-like detection
 - Optional **log1p transform** for heavy-tailed continuous columns
-- Optional **quantile transform** (TabDDPM-style) for continuous columns
+- Optional **quantile transform** (TabDDPM-style) for continuous columns (skips extra standardization)
 - Discrete vocab + most frequent token
 - Windowed batching with **shuffle buffer**
 
@@ -160,7 +160,7 @@ Export process:
 - Diffusion generates residuals
 - Output: `trend + residual`
 - De-normalize continuous values
-- Inverse quantile transform (if enabled)
+- Inverse quantile transform (if enabled; no extra de-standardization)
 - Bound to observed min/max (clamp or sigmoid mapping)
 - Restore discrete tokens from vocab
 - Write to CSV
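
Note for reviewers: the decision rests on the forward quantile transform already yielding ~N(0,1) output, as TabDDPM-style preprocessing does via sklearn's QuantileTransformer(output_distribution="normal"). Below is a minimal sketch of that mapping, assuming the stored stats are a shared probability grid plus per-column empirical quantiles; quantile_to_gaussian, probs, and values are illustrative names, not the actual API of apply_quantile_transform in example/data_utils.py. The tail of the sketch shows why the extra standardization had to go.

import numpy as np
from scipy import stats

def quantile_to_gaussian(col, probs, values, eps=1e-6):
    # Empirical CDF by linear interpolation over the stored grid;
    # `values` are monotone per-column quantiles, so np.interp is valid.
    p = np.interp(col, values, probs)
    # Keep probabilities away from 0/1 so the inverse normal CDF stays finite.
    p = np.clip(p, eps, 1.0 - eps)
    # Standard normal quantiles: the output is ~N(0,1) by construction,
    # which is what makes the removed (x - mean) / std step redundant.
    return stats.norm.ppf(p)

# The failure mode: re-scaling an already ~N(0,1) signal with the *raw*
# column's mean/std (here a heavy-tailed lognormal) wrecks the scale.
raw = np.random.default_rng(0).lognormal(mean=0.0, sigma=2.0, size=10_000)
probs = np.linspace(0.0, 1.0, 1001)
values = np.quantile(raw, probs)
z = quantile_to_gaussian(raw, probs, values)
print(stats.kstest(z, "norm").statistic)      # small: close to N(0,1)
z_bad = (z - raw.mean()) / raw.std()          # the old extra standardization
print(stats.kstest(z_bad, "norm").statistic)  # large: scale mismatch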
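
The export side then only needs the inverse mapping, in the order this patch leaves in example/export_samples.py (inverse quantile transform, then expm1 for log1p columns, then bounding). Another hedged sketch: gaussian_to_raw, used_log1p, lo, and hi are illustrative, not the script's actual API, and the clip stands in for report.md's clamp-or-sigmoid bounding step.

import numpy as np
from scipy.stats import norm

def gaussian_to_raw(z, probs, values, used_log1p=False, lo=None, hi=None):
    p = norm.cdf(z)                    # gaussian sample -> probability in (0, 1)
    x = np.interp(p, probs, values)    # probability -> original scale via the grid
    if used_log1p:
        x = np.expm1(x)                # undo the prep-time log1p transform
    if lo is not None and hi is not None:
        x = np.clip(x, lo, hi)         # bound to observed min/max
    return x

Out-of-range gaussians saturate norm.cdf toward 0/1 and np.interp then returns the grid endpoints, which is exactly the values-pinned-at-max-bounds symptom the decision entry describes.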