Fix quantile transform scaling and document
This commit is contained in:
@@ -48,3 +48,10 @@

- `example/prepare_data.py`
- `example/export_samples.py`
- `example/config.json`
|
## 2026-01-27 — Quantile transform without extra standardization

- **Decision**: When quantile transform is enabled, skip mean/std normalization (quantile output already ~N(0,1)).
- **Why**: Prevent scale mismatch that pushed values to max bounds and blew up KS.
- **Files**:
  - `example/data_utils.py`
  - `example/export_samples.py`
|||||||
@@ -288,6 +288,8 @@ def normalize_cont(

        if not quantile_probs or not quantile_values:
            raise ValueError("use_quantile_transform enabled but quantile stats missing")
        x = apply_quantile_transform(x, cont_cols, quantile_probs, quantile_values)
        # quantile transform already targets N(0,1); skip extra standardization
        return x
    mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
    std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
    return (x - mean_t) / std_t
|||||||
@@ -271,11 +271,12 @@ def main():

    if args.clip_k > 0:
        x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
    if use_quantile:
        # quantile path: inverse transform maps N(0,1) samples straight back
        # to the data scale; no extra de-standardization is applied
        x_cont = inverse_quantile_transform(x_cont, cont_cols, quantile_probs, quantile_values)
    else:
        mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
        std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
        x_cont = x_cont * std_vec + mean_vec
    for i, c in enumerate(cont_cols):
        if transforms.get(c) == "log1p":
            x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
|
|||||||
@@ -144,7 +144,7 @@ Defined in `example/data_utils.py` + `example/prepare_data.py`.

Key steps:

- Streaming mean/std/min/max + int-like detection
- Optional **log1p transform** for heavy-tailed continuous columns
- Optional **quantile transform** (TabDDPM-style) for continuous columns (skips extra standardization)
- Discrete vocab + most frequent token
- Windowed batching with **shuffle buffer**
@@ -160,7 +160,7 @@ Export process:

- Diffusion generates residuals
- Output: `trend + residual`
- De-normalize continuous values
- Inverse quantile transform (if enabled; no extra de-standardization)
- Bound to observed min/max (clamp or sigmoid mapping)
- Restore discrete tokens from vocab
- Write to CSV
Reference in New Issue
Block a user