Fix quantile transform scaling and document

2026-01-27 19:27:00 +08:00
parent 80e91443d2
commit 9e1e7338a2
4 changed files with 15 additions and 5 deletions


@@ -48,3 +48,10 @@
 - `example/prepare_data.py`
 - `example/export_samples.py`
 - `example/config.json`
+
+## 2026-01-27 — Quantile transform without extra standardization
+- **Decision**: When quantile transform is enabled, skip mean/std normalization (quantile output is already ~N(0,1)).
+- **Why**: Prevent the scale mismatch that pushed values to the max bounds and blew up the KS statistic.
+- **Files**:
+  - `example/data_utils.py`
+  - `example/export_samples.py`
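The decision above leans on the fact that a TabDDPM-style quantile transform already maps each column to an approximately standard-normal distribution. A minimal stand-alone demonstration of that claim, using scikit-learn's `QuantileTransformer` as a stand-in for the repo's own helper (whose body is not part of this commit):

```python
# Not the repo's code: sklearn's QuantileTransformer as a stand-in for the
# TabDDPM-style transform referenced in the decision log.
import numpy as np
from sklearn.preprocessing import QuantileTransformer

rng = np.random.default_rng(0)
raw = rng.lognormal(mean=3.0, sigma=1.5, size=(10_000, 1))   # heavy-tailed column

qt = QuantileTransformer(n_quantiles=1000, output_distribution="normal", random_state=0)
z = qt.fit_transform(raw)
print(z.mean(), z.std())   # ~0.0, ~1.0 -> no further mean/std standardization needed

# What the old export path effectively did: rescale by the RAW column's stats
# before inverting the quantile transform.
wrong = z * raw.std() + raw.mean()
print(np.abs(wrong).max())  # hundreds, far beyond the roughly +/-5 range the
                            # inverse mapping can distinguish, so values saturate
                            # at the observed min/max
```

Saturation at the bounds is exactly the "pushed values to the max bounds and blew up the KS statistic" symptom the entry describes.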


@@ -288,6 +288,8 @@ def normalize_cont(
         if not quantile_probs or not quantile_values:
             raise ValueError("use_quantile_transform enabled but quantile stats missing")
         x = apply_quantile_transform(x, cont_cols, quantile_probs, quantile_values)
+        # quantile transform already targets N(0,1); skip extra standardization
+        return x
     mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
     std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
     return (x - mean_t) / std_t
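For reference, a hedged sketch of what `apply_quantile_transform` plausibly does; the real helper lives in `example/data_utils.py` and is not shown in this hunk, and treating `quantile_probs`/`quantile_values` as per-column dicts of increasing arrays is an assumption:

```python
# Hypothetical sketch only; names and data layout are assumptions, not the
# repo's actual implementation.
import numpy as np
import torch
from scipy.stats import norm

def apply_quantile_transform_sketch(x, cont_cols, quantile_probs, quantile_values):
    """Map each continuous column through its empirical CDF (interpolating the
    stored quantiles), then through the inverse standard-normal CDF, so the
    output is approximately N(0, 1) and needs no further standardization."""
    out = x.clone()
    for i, c in enumerate(cont_cols):
        col = x[..., i].cpu().numpy()
        u = np.interp(col, quantile_values[c], quantile_probs[c])  # value -> prob
        u = np.clip(u, 1e-6, 1.0 - 1e-6)                           # keep ppf finite
        out[..., i] = torch.as_tensor(norm.ppf(u), dtype=x.dtype, device=x.device)
    return out
```

The added `return x` then makes this N(0,1) output the final result of `normalize_cont`, instead of standardizing it a second time with the raw-scale mean/std.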


@@ -271,11 +271,12 @@ def main():
     if args.clip_k > 0:
         x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
-    mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
-    std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
-    x_cont = x_cont * std_vec + mean_vec
     if use_quantile:
         x_cont = inverse_quantile_transform(x_cont, cont_cols, quantile_probs, quantile_values)
+    else:
+        mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
+        std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
+        x_cont = x_cont * std_vec + mean_vec
     for i, c in enumerate(cont_cols):
         if transforms.get(c) == "log1p":
             x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
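And the corresponding inverse that `inverse_quantile_transform` is expected to perform (again an assumed sketch, not the body of the helper in this repo):

```python
# Hypothetical sketch; mirrors the forward sketch above.
import numpy as np
import torch
from scipy.stats import norm

def inverse_quantile_transform_sketch(x, cont_cols, quantile_probs, quantile_values):
    """Standard-normal CDF back to probabilities, then interpolate probabilities
    back onto the stored data quantiles."""
    out = x.clone()
    for i, c in enumerate(cont_cols):
        u = norm.cdf(x[..., i].cpu().numpy())                       # N(0,1) -> prob
        vals = np.interp(u, quantile_probs[c], quantile_values[c])  # prob -> value
        out[..., i] = torch.as_tensor(vals, dtype=x.dtype, device=x.device)
    return out
```

Because `norm.cdf` saturates a few standard deviations out, feeding it samples that were first rescaled by the raw mean/std (the deleted lines above) pins most outputs to the first or last stored quantile, consistent with the "pushed values to the max bounds" symptom in the decision log; the new `else` branch keeps the mean/std de-normalization only for the non-quantile path.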


@@ -144,7 +144,7 @@ Defined in `example/data_utils.py` + `example/prepare_data.py`.
 Key steps:
 - Streaming mean/std/min/max + int-like detection
 - Optional **log1p transform** for heavy-tailed continuous columns
-- Optional **quantile transform** (TabDDPM-style) for continuous columns
+- Optional **quantile transform** (TabDDPM-style) for continuous columns (skips extra standardization)
 - Discrete vocab + most frequent token
 - Windowed batching with **shuffle buffer**
@@ -160,7 +160,7 @@ Export process:
 - Diffusion generates residuals
 - Output: `trend + residual`
 - De-normalize continuous values
-- Inverse quantile transform (if enabled)
+- Inverse quantile transform (if enabled; no extra de-standardization)
 - Bound to observed min/max (clamp or sigmoid mapping)
 - Restore discrete tokens from vocab
 - Write to CSV
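The "bound to observed min/max" step is not touched by this commit; a small illustrative sketch of the two options the list names (clamp vs. sigmoid mapping), with hypothetical names:

```python
# Illustrative only; function and argument names are not from the repo.
import torch

def bound_to_range(x, col_min: float, col_max: float, mode: str = "clamp"):
    if mode == "clamp":
        # hard clip to the range seen during the stats pass
        return torch.clamp(x, col_min, col_max)
    # sigmoid mapping: squash onto (col_min, col_max) so samples stay strictly inside
    return col_min + (col_max - col_min) * torch.sigmoid(x)
```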