diff --git a/docs/decisions.md b/docs/decisions.md
index b93e2bd..564243e 100644
--- a/docs/decisions.md
+++ b/docs/decisions.md
@@ -39,3 +39,12 @@
- **Why**: Avoid blind reweighting and find the specific features causing KS to stay high.
- **Files**:
- `example/diagnose_ks.py`
+
+## 2026-01-26 — Quantile transform + sigmoid bounds for continuous features
+- **Decision**: Add optional quantile normalization (TabDDPM-style) and sigmoid-based bounds to reduce KS spikes.
+- **Why**: KS failures are dominated by boundary pile-up and tail mismatch.
+- **Files**:
+ - `example/data_utils.py`
+ - `example/prepare_data.py`
+ - `example/export_samples.py`
+ - `example/config.json`
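The entry above follows the TabDDPM recipe: push each continuous feature through its empirical CDF, then through the standard normal PPF, so the model sees roughly N(0,1) marginals regardless of the raw distribution. For intuition only, here is the same idea via scikit-learn's QuantileTransformer (the diff below implements its own table-based version in `example/data_utils.py`):

```python
import numpy as np
from sklearn.preprocessing import QuantileTransformer

rng = np.random.default_rng(0)
x = rng.lognormal(mean=0.0, sigma=1.0, size=(10_000, 1))  # heavy-tailed feature

qt = QuantileTransformer(n_quantiles=1001, output_distribution="normal", random_state=0)
z = qt.fit_transform(x)           # approximately N(0, 1), whatever the input shape
x_back = qt.inverse_transform(z)  # maps model samples back to the data scale
```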
diff --git a/example/config.json b/example/config.json
index 7a90bcd..26301fe 100644
--- a/example/config.json
+++ b/example/config.json
@@ -42,6 +42,9 @@
"cont_loss_eps": 1e-6,
"cont_target": "x0",
"cont_clamp_x0": 5.0,
+ "use_quantile_transform": true,
+ "quantile_bins": 1001,
+ "cont_bound_mode": "sigmoid",
"shuffle_buffer": 256,
"use_temporal_stage1": true,
"temporal_hidden_dim": 256,
diff --git a/example/config_no_temporal.json b/example/config_no_temporal.json
index 47ef065..a061ba4 100644
--- a/example/config_no_temporal.json
+++ b/example/config_no_temporal.json
@@ -42,6 +42,9 @@
"cont_loss_eps": 1e-6,
"cont_target": "x0",
"cont_clamp_x0": 5.0,
+ "use_quantile_transform": true,
+ "quantile_bins": 1001,
+ "cont_bound_mode": "sigmoid",
"shuffle_buffer": 1024,
"use_temporal_stage1": false,
"sample_batch_size": 4,
diff --git a/example/config_temporal_strong.json b/example/config_temporal_strong.json
index 6dbb8ce..eb7c960 100644
--- a/example/config_temporal_strong.json
+++ b/example/config_temporal_strong.json
@@ -42,6 +42,9 @@
"cont_loss_eps": 1e-6,
"cont_target": "x0",
"cont_clamp_x0": 5.0,
+ "use_quantile_transform": true,
+ "quantile_bins": 1001,
+ "cont_bound_mode": "sigmoid",
"shuffle_buffer": 1024,
"use_temporal_stage1": true,
"temporal_hidden_dim": 512,
diff --git a/example/data_utils.py b/example/data_utils.py
index aa36195..3ca3cb3 100755
--- a/example/data_utils.py
+++ b/example/data_utils.py
@@ -138,6 +138,7 @@ def compute_cont_stats(
cont_cols: List[str],
max_rows: Optional[int] = None,
transforms: Optional[Dict[str, str]] = None,
+ quantile_bins: Optional[int] = None,
):
"""Compute stats on (optionally transformed) values. Returns raw + transformed stats."""
# First pass (raw) for metadata and raw mean/std
@@ -147,10 +148,11 @@ def compute_cont_stats(
if transforms is None:
transforms = {c: "none" for c in cont_cols}
- # Second pass for transformed mean/std
+ # Second pass for transformed mean/std (and optional quantiles)
count = {c: 0 for c in cont_cols}
mean = {c: 0.0 for c in cont_cols}
m2 = {c: 0.0 for c in cont_cols}
+ quantile_values = {c: [] for c in cont_cols} if quantile_bins and quantile_bins > 1 else None
for i, row in enumerate(iter_rows(path)):
for c in cont_cols:
raw_val = row[c]
@@ -161,6 +163,8 @@ def compute_cont_stats(
if x < 0:
x = 0.0
x = math.log1p(x)
+ if quantile_values is not None:
+ quantile_values[c].append(x)
n = count[c] + 1
delta = x - mean[c]
mean[c] += delta / n
@@ -178,6 +182,25 @@ def compute_cont_stats(
var = 0.0
std[c] = var ** 0.5 if var > 0 else 1.0
+ quantile_probs = None
+ quantile_table = None
+ if quantile_values is not None:
+ quantile_probs = [i / (quantile_bins - 1) for i in range(quantile_bins)]
+ quantile_table = {}
+ for c in cont_cols:
+ vals = quantile_values[c]
+ if not vals:
+ quantile_table[c] = [0.0 for _ in quantile_probs]
+ continue
+ vals.sort()
+ n = len(vals)
+ qvals = []
+ for p in quantile_probs:
+ idx = int(round(p * (n - 1)))
+ idx = max(0, min(n - 1, idx))
+ qvals.append(float(vals[idx]))
+ quantile_table[c] = qvals
+
return {
"mean": mean,
"std": std,
@@ -191,6 +214,8 @@ def compute_cont_stats(
"skew": raw["skew"],
"all_pos": raw["all_pos"],
"max_rows": max_rows,
+ "quantile_probs": quantile_probs,
+ "quantile_values": quantile_table,
}
@@ -249,6 +274,9 @@ def normalize_cont(
mean: Dict[str, float],
std: Dict[str, float],
transforms: Optional[Dict[str, str]] = None,
+ quantile_probs: Optional[List[float]] = None,
+ quantile_values: Optional[Dict[str, List[float]]] = None,
+ use_quantile: bool = False,
):
import torch
@@ -256,11 +284,68 @@
for i, c in enumerate(cont_cols):
if transforms.get(c) == "log1p":
x[:, :, i] = torch.log1p(torch.clamp(x[:, :, i], min=0))
+ if use_quantile:
+ if not quantile_probs or not quantile_values:
+ raise ValueError("use_quantile_transform enabled but quantile stats missing")
+ x = apply_quantile_transform(x, cont_cols, quantile_probs, quantile_values)
mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
return (x - mean_t) / std_t
+def _normal_cdf(x):
+    """Standard normal CDF via erf."""
+    import torch
+    return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+
+
+def _normal_ppf(p):
+    """Standard normal inverse CDF; p is clamped away from 0 and 1 for stability."""
+    import torch
+    eps = 1e-6
+    p = torch.clamp(p, eps, 1.0 - eps)
+    return math.sqrt(2.0) * torch.erfinv(2.0 * p - 1.0)
+
+
+def apply_quantile_transform(x, cont_cols, quantile_probs, quantile_values):
+    """Map each feature to ~N(0,1) by linear interpolation of its quantile table (in place)."""
+    import torch
+    probs_t = torch.tensor(quantile_probs, dtype=x.dtype, device=x.device)
+ for i, c in enumerate(cont_cols):
+ q_vals = torch.tensor(quantile_values[c], dtype=x.dtype, device=x.device)
+ v = x[:, :, i]
+ idx = torch.bucketize(v, q_vals)
+ idx = torch.clamp(idx, 1, q_vals.numel() - 1)
+ x0 = q_vals[idx - 1]
+ x1 = q_vals[idx]
+ p0 = probs_t[idx - 1]
+ p1 = probs_t[idx]
+ denom = torch.where((x1 - x0) == 0, torch.ones_like(x1 - x0), (x1 - x0))
+ p = p0 + (v - x0) * (p1 - p0) / denom
+ x[:, :, i] = _normal_ppf(p)
+ return x
+
+
+def inverse_quantile_transform(x, cont_cols, quantile_probs, quantile_values):
+    """Invert apply_quantile_transform: map ~N(0,1) values back to the data scale (in place)."""
+    import torch
+    probs_t = torch.tensor(quantile_probs, dtype=x.dtype, device=x.device)
+ for i, c in enumerate(cont_cols):
+ q_vals = torch.tensor(quantile_values[c], dtype=x.dtype, device=x.device)
+ z = x[:, :, i]
+ p = _normal_cdf(z)
+ idx = torch.bucketize(p, probs_t)
+ idx = torch.clamp(idx, 1, probs_t.numel() - 1)
+ p0 = probs_t[idx - 1]
+ p1 = probs_t[idx]
+ x0 = q_vals[idx - 1]
+ x1 = q_vals[idx]
+ denom = torch.where((p1 - p0) == 0, torch.ones_like(p1 - p0), (p1 - p0))
+ v = x0 + (p - p0) * (x1 - x0) / denom
+ x[:, :, i] = v
+ return x
+
+
def windowed_batches(
path: Union[str, List[str]],
cont_cols: List[str],
@@ -273,6 +354,9 @@ def windowed_batches(
max_batches: Optional[int] = None,
return_file_id: bool = False,
transforms: Optional[Dict[str, str]] = None,
+ quantile_probs: Optional[List[float]] = None,
+ quantile_values: Optional[Dict[str, List[float]]] = None,
+ use_quantile: bool = False,
shuffle_buffer: int = 0,
):
import torch
@@ -316,7 +400,16 @@ def windowed_batches(
if len(batch_cont) == batch_size:
x_cont = torch.tensor(batch_cont, dtype=torch.float32)
x_disc = torch.tensor(batch_disc, dtype=torch.long)
- x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
+ x_cont = normalize_cont(
+ x_cont,
+ cont_cols,
+ mean,
+ std,
+ transforms=transforms,
+ quantile_probs=quantile_probs,
+ quantile_values=quantile_values,
+ use_quantile=use_quantile,
+ )
if return_file_id:
x_file = torch.tensor(batch_file, dtype=torch.long)
yield x_cont, x_disc, x_file
@@ -344,7 +437,16 @@ def windowed_batches(
import torch
x_cont = torch.tensor(batch_cont, dtype=torch.float32)
x_disc = torch.tensor(batch_disc, dtype=torch.long)
- x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
+ x_cont = normalize_cont(
+ x_cont,
+ cont_cols,
+ mean,
+ std,
+ transforms=transforms,
+ quantile_probs=quantile_probs,
+ quantile_values=quantile_values,
+ use_quantile=use_quantile,
+ )
if return_file_id:
x_file = torch.tensor(batch_file, dtype=torch.long)
yield x_cont, x_disc, x_file
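A quick round-trip check of the two transforms above (a sketch; it assumes `example/` is on `sys.path`, as the repo's scripts import `data_utils` directly, and uses a toy uniform quantile table with a hypothetical feature name `"f"`):

```python
import torch
from data_utils import apply_quantile_transform, inverse_quantile_transform

cont_cols = ["f"]                                      # hypothetical feature name
probs = [i / 10 for i in range(11)]                    # 11-point table
table = {"f": [float(v) for v in range(0, 101, 10)]}   # "f" uniform on [0, 100]

x = torch.tensor([[[25.0], [50.0], [75.0]]])           # (batch, window, features)
z = apply_quantile_transform(x.clone(), cont_cols, probs, table)
print(z.squeeze())                                     # ~[-0.6745, 0.0000, 0.6745]
x_back = inverse_quantile_transform(z, cont_cols, probs, table)
print(torch.allclose(x_back, x, atol=1e-4))            # True; both mutate in place
```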
diff --git a/example/export_samples.py b/example/export_samples.py
index 3959ca6..f9ca0a1 100644
--- a/example/export_samples.py
+++ b/example/export_samples.py
@@ -12,7 +12,7 @@ from typing import Dict, List
import torch
import torch.nn.functional as F
-from data_utils import load_split
+from data_utils import load_split, inverse_quantile_transform
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
@@ -112,6 +112,8 @@ def main():
int_like = stats.get("int_like", {})
max_decimals = stats.get("max_decimals", {})
transforms = stats.get("transform", {})
+ quantile_probs = stats.get("quantile_probs")
+ quantile_values = stats.get("quantile_values")
vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
vocab = vocab_json["vocab"]
@@ -140,6 +142,8 @@ def main():
raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob)
cont_target = str(cfg.get("cont_target", "eps"))
cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
+ use_quantile = bool(cfg.get("use_quantile_transform", False))
+ cont_bound_mode = str(cfg.get("cont_bound_mode", "clamp"))
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
@@ -270,15 +274,23 @@
mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
x_cont = x_cont * std_vec + mean_vec
+            if use_quantile:
+                if not quantile_probs or not quantile_values:
+                    raise SystemExit("use_quantile_transform enabled but stats lack quantile tables; re-run prepare_data.py")
+                x_cont = inverse_quantile_transform(x_cont, cont_cols, quantile_probs, quantile_values)
for i, c in enumerate(cont_cols):
if transforms.get(c) == "log1p":
x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
- # clamp to observed min/max per feature
+ # bound to observed min/max per feature
if vmin and vmax:
for i, c in enumerate(cont_cols):
lo = vmin.get(c, None)
hi = vmax.get(c, None)
- if lo is not None and hi is not None:
+ if lo is None or hi is None:
+ continue
+ if cont_bound_mode == "sigmoid":
+ x_cont[:, :, i] = float(lo) + (float(hi) - float(lo)) * torch.sigmoid(x_cont[:, :, i])
+ else:
x_cont[:, :, i] = torch.clamp(x_cont[:, :, i], float(lo), float(hi))
header = read_header(data_path)
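To see what `cont_bound_mode` changes, a toy contrast of the two modes (lo/hi stand in for a feature's observed min/max):

```python
import torch

lo, hi = 0.0, 100.0
x = torch.tensor([-3.0, 0.0, 2.0, 250.0])     # de-normalized samples

hard = torch.clamp(x, lo, hi)                 # mass piles up exactly at lo and hi
soft = lo + (hi - lo) * torch.sigmoid(x)      # smooth map into the open interval
print(hard)  # tensor([  0.,   0.,   2., 100.])
print(soft)  # tensor([  4.7426,  50.0000,  88.0797, 100.0000])
```

Note that the sigmoid is applied after de-normalization, so its useful dynamic range sits near zero; inputs of large magnitude still saturate toward the bounds.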
diff --git a/example/prepare_data.py b/example/prepare_data.py
index e0427ac..3c6b5fc 100755
--- a/example/prepare_data.py
+++ b/example/prepare_data.py
@@ -17,6 +17,14 @@ OUT_VOCAB = BASE_DIR / "results" / "disc_vocab.json"
def main(max_rows: Optional[int] = None):
+ config_path = BASE_DIR / "config.json"
+ use_quantile = False
+ quantile_bins = None
+ if config_path.exists():
+ cfg = json.loads(config_path.read_text(encoding="utf-8"))
+ use_quantile = bool(cfg.get("use_quantile_transform", False))
+ quantile_bins = int(cfg.get("quantile_bins", 0)) if use_quantile else None
+
split = load_split(safe_path(SPLIT_PATH))
time_col = split.get("time_column", "time")
cont_cols = [c for c in split["continuous"] if c != time_col]
@@ -28,7 +36,13 @@ def main(max_rows: Optional[int] = None):
data_paths = [safe_path(p) for p in data_paths]
transforms, _ = choose_cont_transforms(data_paths, cont_cols, max_rows=max_rows)
- cont_stats = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows, transforms=transforms)
+ cont_stats = compute_cont_stats(
+ data_paths,
+ cont_cols,
+ max_rows=max_rows,
+ transforms=transforms,
+ quantile_bins=quantile_bins,
+ )
vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows)
ensure_dir(OUT_STATS.parent)
@@ -46,6 +60,8 @@ def main(max_rows: Optional[int] = None):
"transform": cont_stats["transform"],
"skew": cont_stats["skew"],
"max_rows": cont_stats["max_rows"],
+ "quantile_probs": cont_stats["quantile_probs"],
+ "quantile_values": cont_stats["quantile_values"],
},
f,
indent=2,
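When `quantile_bins` is set, `compute_cont_stats` additionally records a nearest-rank quantile table per column (the `quantile_probs` / `quantile_values` fields written above). A self-contained sketch of that table construction, mirroring the logic added in `data_utils.py`:

```python
vals = sorted([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0])  # transformed values, one column
bins = 5
probs = [i / (bins - 1) for i in range(bins)]            # [0.0, 0.25, 0.5, 0.75, 1.0]
n = len(vals)
table = [float(vals[max(0, min(n - 1, int(round(p * (n - 1)))))]) for p in probs]
print(table)  # [1.0, 2.0, 4.0, 5.0, 9.0]
```

Unlike the streaming mean/std, this step keeps all scanned values in memory per column; `max_rows` bounds that cost.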
diff --git a/example/results/cdf_P1_B3004.svg b/example/results/cdf_P1_B3004.svg
new file mode 100644
index 0000000..9dda8f1
--- /dev/null
+++ b/example/results/cdf_P1_B3004.svg
@@ -0,0 +1,12 @@
+
\ No newline at end of file
diff --git a/example/results/cdf_P1_LIT01.svg b/example/results/cdf_P1_LIT01.svg
new file mode 100644
index 0000000..3179261
--- /dev/null
+++ b/example/results/cdf_P1_LIT01.svg
@@ -0,0 +1,12 @@
+
\ No newline at end of file
diff --git a/example/results/cdf_P1_PCV02Z.svg b/example/results/cdf_P1_PCV02Z.svg
new file mode 100644
index 0000000..325b214
--- /dev/null
+++ b/example/results/cdf_P1_PCV02Z.svg
@@ -0,0 +1,12 @@
+
\ No newline at end of file
diff --git a/example/results/cdf_P2_MSD.svg b/example/results/cdf_P2_MSD.svg
new file mode 100644
index 0000000..c6e374f
--- /dev/null
+++ b/example/results/cdf_P2_MSD.svg
@@ -0,0 +1,12 @@
+
\ No newline at end of file
diff --git a/example/results/cdf_P2_SIT01.svg b/example/results/cdf_P2_SIT01.svg
new file mode 100644
index 0000000..3371a7f
--- /dev/null
+++ b/example/results/cdf_P2_SIT01.svg
@@ -0,0 +1,12 @@
+
\ No newline at end of file
diff --git a/example/results/cdf_P2_SIT02.svg b/example/results/cdf_P2_SIT02.svg
new file mode 100644
index 0000000..f5904db
--- /dev/null
+++ b/example/results/cdf_P2_SIT02.svg
@@ -0,0 +1,12 @@
+
\ No newline at end of file
diff --git a/example/results/cdf_P3_LCP01D.svg b/example/results/cdf_P3_LCP01D.svg
new file mode 100644
index 0000000..1006e21
--- /dev/null
+++ b/example/results/cdf_P3_LCP01D.svg
@@ -0,0 +1,12 @@
+
\ No newline at end of file
diff --git a/example/results/cdf_P3_PIT01.svg b/example/results/cdf_P3_PIT01.svg
new file mode 100644
index 0000000..fffb73a
--- /dev/null
+++ b/example/results/cdf_P3_PIT01.svg
@@ -0,0 +1,12 @@
+
\ No newline at end of file
diff --git a/example/results/cdf_P4_HT_FD.svg b/example/results/cdf_P4_HT_FD.svg
new file mode 100644
index 0000000..016f803
--- /dev/null
+++ b/example/results/cdf_P4_HT_FD.svg
@@ -0,0 +1,12 @@
+
\ No newline at end of file
diff --git a/example/results/cdf_P4_ST_PT01.svg b/example/results/cdf_P4_ST_PT01.svg
new file mode 100644
index 0000000..5c30b49
--- /dev/null
+++ b/example/results/cdf_P4_ST_PT01.svg
@@ -0,0 +1,12 @@
+
\ No newline at end of file
diff --git a/example/results/ks_diagnosis.csv b/example/results/ks_diagnosis.csv
new file mode 100644
index 0000000..864fa32
--- /dev/null
+++ b/example/results/ks_diagnosis.csv
@@ -0,0 +1 @@
+feature,ks,boundary_frac,mean_shift,std_ratio,diagnosis,gen_frac_at_min,gen_frac_at_max
diff --git a/example/results/ks_per_feature.csv b/example/results/ks_per_feature.csv
new file mode 100644
index 0000000..c1ddc23
--- /dev/null
+++ b/example/results/ks_per_feature.csv
@@ -0,0 +1,54 @@
+feature,ks,gen_frac_at_min,gen_frac_at_max,real_n,gen_n,real_min,real_max
+P2_MSD,1.0,1.0,1.0,92163,52,763.19324,763.19324
+P3_PIT01,0.9141619071227483,0.0,0.0,92163,52,-24.0,3847.0
+P2_SIT02,0.8628397930422604,0.0,0.0,92163,52,757.68005,826.50775
+P2_SIT01,0.8617182433464456,0.0,0.0,92163,52,758.0,827.0
+P3_LCP01D,0.8261961040597803,0.8269230769230769,0.0,92163,52,-8.0,13816.0
+P4_HT_FD,0.7983631008272134,0.0,0.0,92163,52,-0.0217,0.02684
+P1_B3004,0.7794726567227461,0.0,0.0,92163,52,369.75601,447.83438
+P1_LIT01,0.7761347161675927,0.0,0.0,92163,52,356.09085,459.24484
+P4_ST_PT01,0.7676921073783155,0.0,0.0,92163,52,9914.0,10330.0
+P1_PCV02Z,0.7670214728253204,0.0,0.0,92163,52,11.76605,12.04071
+P1_PIT02,0.7347702941026726,0.0,0.0,92163,52,0.17105,2.34161
+P4_ST_PO,0.7212397099119537,0.019230769230769232,0.0,92163,52,233.66968,498.60754
+P1_B2016,0.6999296397102459,0.0,0.0,92163,52,0.9508,2.0523
+P4_ST_LD,0.6933532896148046,0.0,0.0,92163,52,230.55914,499.62018
+P4_LD,0.6897361614330463,0.0,0.0,92163,52,231.33685,498.58942
+P3_LIT01,0.6615471835435378,0.0,0.0,92163,52,5047.0,19680.0
+P1_PCV01D,0.6231695265662259,0.0,0.4807692307692308,92163,52,24.95222,100.0
+P1_B2004,0.617741226038482,0.0,0.0,92163,52,0.02978,0.10196
+P2_CO_rpm,0.6100514640031582,0.0,0.0,92163,52,53993.0,54183.0
+P4_ST_GOV,0.6084888062037244,0.0,0.0,92163,52,12665.0,26898.0
+P1_FCV02Z,0.5961538461538461,0.5961538461538461,0.0,92163,52,-1.89057,97.38312
+P1_B4002,0.5783991406529735,0.0,0.0,92163,52,31.41343,33.6555
+P1_FT01Z,0.5633543078775981,0.0,0.0,92163,52,0.0,1365.69287
+P1_PCV01Z,0.547708324465266,0.0,0.5,92163,52,25.57526,100.0
+P1_B3005,0.5359248538751159,0.0,0.0,92163,52,890.07843,1121.94116
+P1_B4005,0.5101396438918004,0.019230769230769232,0.0,92163,52,0.0,100.0
+P1_FT02,0.5049748814600219,0.0,0.0,92163,52,4.99723,2005.23364
+P3_FIT01,0.497898998346575,0.0,0.0,92163,52,-27.0,5421.0
+P2_24Vdc,0.4871763572733593,0.0,0.0,92163,52,28.01351,28.04294
+P4_HT_LD,0.48082202185258727,0.6346153846153846,0.0,92163,52,-0.00723,83.04398
+P1_B400B,0.4544694642184959,0.0,0.0,92163,52,25.02598,2855.56567
+P2_VXT03,0.45055916816276176,0.0,0.0,92163,52,-2.135,0.1491
+P2_VYT03,0.4479521650186668,0.0,0.0,92163,52,4.6083,7.2547
+P2_VXT02,0.44536394131133883,0.0,0.0,92163,52,-4.3925,-1.8818
+P4_HT_PO,0.42936573912941867,0.019230769230769232,0.0,92163,52,0.05423,83.04401
+P3_LCV01D,0.4154990030205681,0.25,0.0,92163,52,-288.0,17776.0
+P1_FT03,0.41513384730565156,0.0,0.0,92163,52,187.91197,331.15381
+P1_FT02Z,0.40720829900869615,0.0,0.0,92163,52,25.02598,2856.88574
+P2_VT01,0.36856126144398005,0.0,0.0,92163,52,11.76163,12.06125
+P1_TIT02,0.3579625646534276,0.0,0.0,92163,52,34.99451,40.4419
+P1_FCV03Z,0.35665363791075844,0.0,0.0,92163,52,46.20513,75.3189
+P1_LCV01Z,0.3512624789357317,0.0,0.0,92163,52,0.29907,28.52783
+P1_FCV03D,0.30470616858592514,0.0,0.0,92163,52,45.78336,74.1622
+P4_ST_TT01,0.30430700122441934,0.0,0.21153846153846154,92163,52,27539.0,27629.0
+P2_HILout,0.3041348563873872,0.0,0.0,92163,52,673.80371,768.76831
+P4_ST_FD,0.30162947086224323,0.0,0.0,92163,52,-0.05244,0.05035
+P1_B4022,0.2862201083531769,0.0,0.0,92163,52,34.21529,38.63682
+P1_TIT01,0.2807849220319517,0.0,0.0,92163,52,34.68933,36.94763
+P1_LCV01D,0.28024261363019864,0.0,0.0,92163,52,3.17127,28.23791
+P1_FT03Z,0.24018503170386246,0.0,0.0,92163,52,867.43927,1146.92163
+P1_PIT01,0.21846515245981413,0.0,0.0,92163,52,0.88211,2.38739
+P1_FT01,0.21452397466361856,0.0,0.0,92163,52,-9.88007,462.57019
+P2_VYT02,0.18998029411101902,0.0,0.0,92163,52,2.4459,5.1248
diff --git a/example/results/ks_summary.json b/example/results/ks_summary.json
new file mode 100644
index 0000000..9fa6b90
--- /dev/null
+++ b/example/results/ks_summary.json
@@ -0,0 +1,17 @@
+{
+ "generated_rows": 52,
+ "reference_rows_per_file": 50000,
+ "stride": 10,
+ "top_k_features": [
+ "P2_MSD",
+ "P3_PIT01",
+ "P2_SIT02",
+ "P2_SIT01",
+ "P3_LCP01D",
+ "P4_HT_FD",
+ "P1_B3004",
+ "P1_LIT01",
+ "P4_ST_PT01",
+ "P1_PCV02Z"
+ ]
+}
\ No newline at end of file
diff --git a/example/train.py b/example/train.py
index 4a98c8d..3e27a29 100755
--- a/example/train.py
+++ b/example/train.py
@@ -173,6 +173,9 @@ def main():
std = stats["std"]
transforms = stats.get("transform", {})
raw_std = stats.get("raw_std", std)
+ quantile_probs = stats.get("quantile_probs")
+ quantile_values = stats.get("quantile_values")
+ use_quantile = bool(config.get("use_quantile_transform", False))
vocab = load_json(config["vocab_path"])["vocab"]
vocab_sizes = [len(vocab[c]) for c in disc_cols]
@@ -253,6 +256,9 @@ def main():
max_batches=int(config["max_batches"]),
return_file_id=False,
transforms=transforms,
+ quantile_probs=quantile_probs,
+ quantile_values=quantile_values,
+ use_quantile=use_quantile,
shuffle_buffer=int(config.get("shuffle_buffer", 0)),
)
):
@@ -284,6 +290,9 @@ def main():
max_batches=int(config["max_batches"]),
return_file_id=use_condition,
transforms=transforms,
+ quantile_probs=quantile_probs,
+ quantile_values=quantile_values,
+ use_quantile=use_quantile,
shuffle_buffer=int(config.get("shuffle_buffer", 0)),
)
):
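`train.py` only threads the new stats through `windowed_batches`; the invariant that matters is that export inverts `normalize_cont` exactly. A minimal self-contained check (here `mean`/`std` are the pre-quantile, log1p-space stats; they are applied after the Gaussian map and undone first at export, so the composition is still the identity):

```python
import torch
from data_utils import normalize_cont, inverse_quantile_transform

cont_cols = ["f"]                            # hypothetical feature
mean, std = {"f": 0.1}, {"f": 2.0}           # pre-quantile stats
probs = [i / 4 for i in range(5)]
table = {"f": [0.0, 1.0, 2.0, 3.0, 4.0]}

x = torch.tensor([[[1.5], [2.5]]])
z = normalize_cont(x.clone(), cont_cols, mean, std, transforms={"f": "none"},
                   quantile_probs=probs, quantile_values=table, use_quantile=True)
# Export order: de-standardize first, then invert the quantile map.
x_back = inverse_quantile_transform(z * std["f"] + mean["f"], cont_cols, probs, table)
print(torch.allclose(x_back, x, atol=1e-4))  # True
```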
diff --git a/report.md b/report.md
index 64e78a5..356cd3c 100644
--- a/report.md
+++ b/report.md
@@ -144,6 +144,7 @@ Defined in `example/data_utils.py` + `example/prepare_data.py`.
Key steps:
- Streaming mean/std/min/max + int-like detection
- Optional **log1p transform** for heavy-tailed continuous columns
+- Optional **quantile transform** (TabDDPM-style) for continuous columns
- Discrete vocab + most frequent token
- Windowed batching with **shuffle buffer**
@@ -159,7 +160,8 @@ Export process:
- Diffusion generates residuals
- Output: `trend + residual`
- De-normalize continuous values
-- Clamp to observed min/max
+- Inverse quantile transform (if enabled)
+- Bound to observed min/max (clamp or sigmoid mapping)
- Restore discrete tokens from vocab
- Write to CSV