Fix quantile transform scaling and document
This commit is contained in:
@@ -48,3 +48,10 @@
|
||||
- `example/prepare_data.py`
|
||||
- `example/export_samples.py`
|
||||
- `example/config.json`
|
||||
|
||||
## 2026-01-27 — Quantile transform without extra standardization
|
||||
- **Decision**: When quantile transform is enabled, skip mean/std normalization (quantile output already ~N(0,1)).
|
||||
- **Why**: Prevent scale mismatch that pushed values to max bounds and blew up KS.
|
||||
- **Files**:
|
||||
- `example/data_utils.py`
|
||||
- `example/export_samples.py`
|
||||
|
||||
@@ -288,6 +288,8 @@ def normalize_cont(
|
||||
if not quantile_probs or not quantile_values:
|
||||
raise ValueError("use_quantile_transform enabled but quantile stats missing")
|
||||
x = apply_quantile_transform(x, cont_cols, quantile_probs, quantile_values)
|
||||
# quantile transform already targets N(0,1); skip extra standardization
|
||||
return x
|
||||
mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
|
||||
std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
|
||||
return (x - mean_t) / std_t
|
||||
|
||||
@@ -271,11 +271,12 @@ def main():
|
||||
if args.clip_k > 0:
|
||||
x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
|
||||
|
||||
mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
|
||||
std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
|
||||
x_cont = x_cont * std_vec + mean_vec
|
||||
if use_quantile:
|
||||
x_cont = inverse_quantile_transform(x_cont, cont_cols, quantile_probs, quantile_values)
|
||||
else:
|
||||
mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
|
||||
std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
|
||||
x_cont = x_cont * std_vec + mean_vec
|
||||
for i, c in enumerate(cont_cols):
|
||||
if transforms.get(c) == "log1p":
|
||||
x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
|
||||
|
||||
@@ -144,7 +144,7 @@ Defined in `example/data_utils.py` + `example/prepare_data.py`.
|
||||
Key steps:
|
||||
- Streaming mean/std/min/max + int-like detection
|
||||
- Optional **log1p transform** for heavy-tailed continuous columns
|
||||
- Optional **quantile transform** (TabDDPM-style) for continuous columns
|
||||
- Optional **quantile transform** (TabDDPM-style) for continuous columns (skips extra standardization)
|
||||
- Discrete vocab + most frequent token
|
||||
- Windowed batching with **shuffle buffer**
|
||||
|
||||
@@ -160,7 +160,7 @@ Export process:
|
||||
- Diffusion generates residuals
|
||||
- Output: `trend + residual`
|
||||
- De-normalize continuous values
|
||||
- Inverse quantile transform (if enabled)
|
||||
- Inverse quantile transform (if enabled; no extra de-standardization)
|
||||
- Bound to observed min/max (clamp or sigmoid mapping)
|
||||
- Restore discrete tokens from vocab
|
||||
- Write to CSV
|
||||
|
||||
Reference in New Issue
Block a user