diff --git a/docs/decisions.md b/docs/decisions.md index b93e2bd..564243e 100644 --- a/docs/decisions.md +++ b/docs/decisions.md @@ -39,3 +39,12 @@ - **Why**: Avoid blind reweighting and find the specific features causing KS to stay high. - **Files**: - `example/diagnose_ks.py` + +## 2026-01-26 — Quantile transform + sigmoid bounds for continuous features +- **Decision**: Add optional quantile normalization (TabDDPM-style) and sigmoid-based bounds to reduce KS spikes. +- **Why**: KS failures are dominated by boundary pile-up and tail mismatch. +- **Files**: + - `example/data_utils.py` + - `example/prepare_data.py` + - `example/export_samples.py` + - `example/config.json` diff --git a/example/config.json b/example/config.json index 7a90bcd..26301fe 100644 --- a/example/config.json +++ b/example/config.json @@ -42,6 +42,9 @@ "cont_loss_eps": 1e-6, "cont_target": "x0", "cont_clamp_x0": 5.0, + "use_quantile_transform": true, + "quantile_bins": 1001, + "cont_bound_mode": "sigmoid", "shuffle_buffer": 256, "use_temporal_stage1": true, "temporal_hidden_dim": 256, diff --git a/example/config_no_temporal.json b/example/config_no_temporal.json index 47ef065..a061ba4 100644 --- a/example/config_no_temporal.json +++ b/example/config_no_temporal.json @@ -42,6 +42,9 @@ "cont_loss_eps": 1e-6, "cont_target": "x0", "cont_clamp_x0": 5.0, + "use_quantile_transform": true, + "quantile_bins": 1001, + "cont_bound_mode": "sigmoid", "shuffle_buffer": 1024, "use_temporal_stage1": false, "sample_batch_size": 4, diff --git a/example/config_temporal_strong.json b/example/config_temporal_strong.json index 6dbb8ce..eb7c960 100644 --- a/example/config_temporal_strong.json +++ b/example/config_temporal_strong.json @@ -42,6 +42,9 @@ "cont_loss_eps": 1e-6, "cont_target": "x0", "cont_clamp_x0": 5.0, + "use_quantile_transform": true, + "quantile_bins": 1001, + "cont_bound_mode": "sigmoid", "shuffle_buffer": 1024, "use_temporal_stage1": true, "temporal_hidden_dim": 512, diff --git a/example/data_utils.py b/example/data_utils.py index aa36195..3ca3cb3 100755 --- a/example/data_utils.py +++ b/example/data_utils.py @@ -138,6 +138,7 @@ def compute_cont_stats( cont_cols: List[str], max_rows: Optional[int] = None, transforms: Optional[Dict[str, str]] = None, + quantile_bins: Optional[int] = None, ): """Compute stats on (optionally transformed) values. Returns raw + transformed stats.""" # First pass (raw) for metadata and raw mean/std @@ -147,10 +148,11 @@ def compute_cont_stats( if transforms is None: transforms = {c: "none" for c in cont_cols} - # Second pass for transformed mean/std + # Second pass for transformed mean/std (and optional quantiles) count = {c: 0 for c in cont_cols} mean = {c: 0.0 for c in cont_cols} m2 = {c: 0.0 for c in cont_cols} + quantile_values = {c: [] for c in cont_cols} if quantile_bins and quantile_bins > 1 else None for i, row in enumerate(iter_rows(path)): for c in cont_cols: raw_val = row[c] @@ -161,6 +163,8 @@ def compute_cont_stats( if x < 0: x = 0.0 x = math.log1p(x) + if quantile_values is not None: + quantile_values[c].append(x) n = count[c] + 1 delta = x - mean[c] mean[c] += delta / n @@ -178,6 +182,25 @@ def compute_cont_stats( var = 0.0 std[c] = var ** 0.5 if var > 0 else 1.0 + quantile_probs = None + quantile_table = None + if quantile_values is not None: + quantile_probs = [i / (quantile_bins - 1) for i in range(quantile_bins)] + quantile_table = {} + for c in cont_cols: + vals = quantile_values[c] + if not vals: + quantile_table[c] = [0.0 for _ in quantile_probs] + continue + vals.sort() + n = len(vals) + qvals = [] + for p in quantile_probs: + idx = int(round(p * (n - 1))) + idx = max(0, min(n - 1, idx)) + qvals.append(float(vals[idx])) + quantile_table[c] = qvals + return { "mean": mean, "std": std, @@ -191,6 +214,8 @@ def compute_cont_stats( "skew": raw["skew"], "all_pos": raw["all_pos"], "max_rows": max_rows, + "quantile_probs": quantile_probs, + "quantile_values": quantile_table, } @@ -249,6 +274,9 @@ def normalize_cont( mean: Dict[str, float], std: Dict[str, float], transforms: Optional[Dict[str, str]] = None, + quantile_probs: Optional[List[float]] = None, + quantile_values: Optional[Dict[str, List[float]]] = None, + use_quantile: bool = False, ): import torch @@ -256,11 +284,64 @@ def normalize_cont( for i, c in enumerate(cont_cols): if transforms.get(c) == "log1p": x[:, :, i] = torch.log1p(torch.clamp(x[:, :, i], min=0)) + if use_quantile: + if not quantile_probs or not quantile_values: + raise ValueError("use_quantile_transform enabled but quantile stats missing") + x = apply_quantile_transform(x, cont_cols, quantile_probs, quantile_values) mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device) std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device) return (x - mean_t) / std_t +def _normal_cdf(x): + import torch + return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def _normal_ppf(p): + import torch + eps = 1e-6 + p = torch.clamp(p, eps, 1.0 - eps) + return math.sqrt(2.0) * torch.erfinv(2.0 * p - 1.0) + + +def apply_quantile_transform(x, cont_cols, quantile_probs, quantile_values): + import torch + probs_t = torch.tensor(quantile_probs, dtype=x.dtype, device=x.device) + for i, c in enumerate(cont_cols): + q_vals = torch.tensor(quantile_values[c], dtype=x.dtype, device=x.device) + v = x[:, :, i] + idx = torch.bucketize(v, q_vals) + idx = torch.clamp(idx, 1, q_vals.numel() - 1) + x0 = q_vals[idx - 1] + x1 = q_vals[idx] + p0 = probs_t[idx - 1] + p1 = probs_t[idx] + denom = torch.where((x1 - x0) == 0, torch.ones_like(x1 - x0), (x1 - x0)) + p = p0 + (v - x0) * (p1 - p0) / denom + x[:, :, i] = _normal_ppf(p) + return x + + +def inverse_quantile_transform(x, cont_cols, quantile_probs, quantile_values): + import torch + probs_t = torch.tensor(quantile_probs, dtype=x.dtype, device=x.device) + for i, c in enumerate(cont_cols): + q_vals = torch.tensor(quantile_values[c], dtype=x.dtype, device=x.device) + z = x[:, :, i] + p = _normal_cdf(z) + idx = torch.bucketize(p, probs_t) + idx = torch.clamp(idx, 1, probs_t.numel() - 1) + p0 = probs_t[idx - 1] + p1 = probs_t[idx] + x0 = q_vals[idx - 1] + x1 = q_vals[idx] + denom = torch.where((p1 - p0) == 0, torch.ones_like(p1 - p0), (p1 - p0)) + v = x0 + (p - p0) * (x1 - x0) / denom + x[:, :, i] = v + return x + + def windowed_batches( path: Union[str, List[str]], cont_cols: List[str], @@ -273,6 +354,9 @@ def windowed_batches( max_batches: Optional[int] = None, return_file_id: bool = False, transforms: Optional[Dict[str, str]] = None, + quantile_probs: Optional[List[float]] = None, + quantile_values: Optional[Dict[str, List[float]]] = None, + use_quantile: bool = False, shuffle_buffer: int = 0, ): import torch @@ -316,7 +400,16 @@ def windowed_batches( if len(batch_cont) == batch_size: x_cont = torch.tensor(batch_cont, dtype=torch.float32) x_disc = torch.tensor(batch_disc, dtype=torch.long) - x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms) + x_cont = normalize_cont( + x_cont, + cont_cols, + mean, + std, + transforms=transforms, + quantile_probs=quantile_probs, + quantile_values=quantile_values, + use_quantile=use_quantile, + ) if return_file_id: x_file = torch.tensor(batch_file, dtype=torch.long) yield x_cont, x_disc, x_file @@ -344,7 +437,16 @@ def windowed_batches( import torch x_cont = torch.tensor(batch_cont, dtype=torch.float32) x_disc = torch.tensor(batch_disc, dtype=torch.long) - x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms) + x_cont = normalize_cont( + x_cont, + cont_cols, + mean, + std, + transforms=transforms, + quantile_probs=quantile_probs, + quantile_values=quantile_values, + use_quantile=use_quantile, + ) if return_file_id: x_file = torch.tensor(batch_file, dtype=torch.long) yield x_cont, x_disc, x_file diff --git a/example/export_samples.py b/example/export_samples.py index 3959ca6..f9ca0a1 100644 --- a/example/export_samples.py +++ b/example/export_samples.py @@ -12,7 +12,7 @@ from typing import Dict, List import torch import torch.nn.functional as F -from data_utils import load_split +from data_utils import load_split, inverse_quantile_transform from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path @@ -112,6 +112,8 @@ def main(): int_like = stats.get("int_like", {}) max_decimals = stats.get("max_decimals", {}) transforms = stats.get("transform", {}) + quantile_probs = stats.get("quantile_probs") + quantile_values = stats.get("quantile_values") vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8")) vocab = vocab_json["vocab"] @@ -140,6 +142,8 @@ def main(): raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob) cont_target = str(cfg.get("cont_target", "eps")) cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0)) + use_quantile = bool(cfg.get("use_quantile_transform", False)) + cont_bound_mode = str(cfg.get("cont_bound_mode", "clamp")) use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False)) temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256)) temporal_num_layers = int(cfg.get("temporal_num_layers", 1)) @@ -270,15 +274,21 @@ def main(): mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype) std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype) x_cont = x_cont * std_vec + mean_vec + if use_quantile: + x_cont = inverse_quantile_transform(x_cont, cont_cols, quantile_probs, quantile_values) for i, c in enumerate(cont_cols): if transforms.get(c) == "log1p": x_cont[:, :, i] = torch.expm1(x_cont[:, :, i]) - # clamp to observed min/max per feature + # bound to observed min/max per feature if vmin and vmax: for i, c in enumerate(cont_cols): lo = vmin.get(c, None) hi = vmax.get(c, None) - if lo is not None and hi is not None: + if lo is None or hi is None: + continue + if cont_bound_mode == "sigmoid": + x_cont[:, :, i] = float(lo) + (float(hi) - float(lo)) * torch.sigmoid(x_cont[:, :, i]) + else: x_cont[:, :, i] = torch.clamp(x_cont[:, :, i], float(lo), float(hi)) header = read_header(data_path) diff --git a/example/prepare_data.py b/example/prepare_data.py index e0427ac..3c6b5fc 100755 --- a/example/prepare_data.py +++ b/example/prepare_data.py @@ -17,6 +17,14 @@ OUT_VOCAB = BASE_DIR / "results" / "disc_vocab.json" def main(max_rows: Optional[int] = None): + config_path = BASE_DIR / "config.json" + use_quantile = False + quantile_bins = None + if config_path.exists(): + cfg = json.loads(config_path.read_text(encoding="utf-8")) + use_quantile = bool(cfg.get("use_quantile_transform", False)) + quantile_bins = int(cfg.get("quantile_bins", 0)) if use_quantile else None + split = load_split(safe_path(SPLIT_PATH)) time_col = split.get("time_column", "time") cont_cols = [c for c in split["continuous"] if c != time_col] @@ -28,7 +36,13 @@ def main(max_rows: Optional[int] = None): data_paths = [safe_path(p) for p in data_paths] transforms, _ = choose_cont_transforms(data_paths, cont_cols, max_rows=max_rows) - cont_stats = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows, transforms=transforms) + cont_stats = compute_cont_stats( + data_paths, + cont_cols, + max_rows=max_rows, + transforms=transforms, + quantile_bins=quantile_bins, + ) vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows) ensure_dir(OUT_STATS.parent) @@ -46,6 +60,8 @@ def main(max_rows: Optional[int] = None): "transform": cont_stats["transform"], "skew": cont_stats["skew"], "max_rows": cont_stats["max_rows"], + "quantile_probs": cont_stats["quantile_probs"], + "quantile_values": cont_stats["quantile_values"], }, f, indent=2, diff --git a/example/results/cdf_P1_B3004.svg b/example/results/cdf_P1_B3004.svg new file mode 100644 index 0000000..9dda8f1 --- /dev/null +++ b/example/results/cdf_P1_B3004.svg @@ -0,0 +1,12 @@ + + +CDF 비교: P1_B3004 + + + + +real +generated + + + \ No newline at end of file diff --git a/example/results/cdf_P1_LIT01.svg b/example/results/cdf_P1_LIT01.svg new file mode 100644 index 0000000..3179261 --- /dev/null +++ b/example/results/cdf_P1_LIT01.svg @@ -0,0 +1,12 @@ + + +CDF 비교: P1_LIT01 + + + + +real +generated + + + \ No newline at end of file diff --git a/example/results/cdf_P1_PCV02Z.svg b/example/results/cdf_P1_PCV02Z.svg new file mode 100644 index 0000000..325b214 --- /dev/null +++ b/example/results/cdf_P1_PCV02Z.svg @@ -0,0 +1,12 @@ + + +CDF 비교: P1_PCV02Z + + + + +real +generated + + + \ No newline at end of file diff --git a/example/results/cdf_P2_MSD.svg b/example/results/cdf_P2_MSD.svg new file mode 100644 index 0000000..c6e374f --- /dev/null +++ b/example/results/cdf_P2_MSD.svg @@ -0,0 +1,12 @@ + + +CDF 비교: P2_MSD + + + + +real +generated + + + \ No newline at end of file diff --git a/example/results/cdf_P2_SIT01.svg b/example/results/cdf_P2_SIT01.svg new file mode 100644 index 0000000..3371a7f --- /dev/null +++ b/example/results/cdf_P2_SIT01.svg @@ -0,0 +1,12 @@ + + +CDF 비교: P2_SIT01 + + + + +real +generated + + + \ No newline at end of file diff --git a/example/results/cdf_P2_SIT02.svg b/example/results/cdf_P2_SIT02.svg new file mode 100644 index 0000000..f5904db --- /dev/null +++ b/example/results/cdf_P2_SIT02.svg @@ -0,0 +1,12 @@ + + +CDF 비교: P2_SIT02 + + + + +real +generated + + + \ No newline at end of file diff --git a/example/results/cdf_P3_LCP01D.svg b/example/results/cdf_P3_LCP01D.svg new file mode 100644 index 0000000..1006e21 --- /dev/null +++ b/example/results/cdf_P3_LCP01D.svg @@ -0,0 +1,12 @@ + + +CDF 비교: P3_LCP01D + + + + +real +generated + + + \ No newline at end of file diff --git a/example/results/cdf_P3_PIT01.svg b/example/results/cdf_P3_PIT01.svg new file mode 100644 index 0000000..fffb73a --- /dev/null +++ b/example/results/cdf_P3_PIT01.svg @@ -0,0 +1,12 @@ + + +CDF 비교: P3_PIT01 + + + + +real +generated + + + \ No newline at end of file diff --git a/example/results/cdf_P4_HT_FD.svg b/example/results/cdf_P4_HT_FD.svg new file mode 100644 index 0000000..016f803 --- /dev/null +++ b/example/results/cdf_P4_HT_FD.svg @@ -0,0 +1,12 @@ + + +CDF 비교: P4_HT_FD + + + + +real +generated + + + \ No newline at end of file diff --git a/example/results/cdf_P4_ST_PT01.svg b/example/results/cdf_P4_ST_PT01.svg new file mode 100644 index 0000000..5c30b49 --- /dev/null +++ b/example/results/cdf_P4_ST_PT01.svg @@ -0,0 +1,12 @@ + + +CDF 비교: P4_ST_PT01 + + + + +real +generated + + + \ No newline at end of file diff --git a/example/results/ks_diagnosis.csv b/example/results/ks_diagnosis.csv new file mode 100644 index 0000000..864fa32 --- /dev/null +++ b/example/results/ks_diagnosis.csv @@ -0,0 +1 @@ +feature,ks,boundary_frac,mean_shift,std_ratio,diagnosis,gen_frac_at_min,gen_frac_at_max diff --git a/example/results/ks_per_feature.csv b/example/results/ks_per_feature.csv new file mode 100644 index 0000000..c1ddc23 --- /dev/null +++ b/example/results/ks_per_feature.csv @@ -0,0 +1,54 @@ +feature,ks,gen_frac_at_min,gen_frac_at_max,real_n,gen_n,real_min,real_max +P2_MSD,1.0,1.0,1.0,92163,52,763.19324,763.19324 +P3_PIT01,0.9141619071227483,0.0,0.0,92163,52,-24.0,3847.0 +P2_SIT02,0.8628397930422604,0.0,0.0,92163,52,757.68005,826.50775 +P2_SIT01,0.8617182433464456,0.0,0.0,92163,52,758.0,827.0 +P3_LCP01D,0.8261961040597803,0.8269230769230769,0.0,92163,52,-8.0,13816.0 +P4_HT_FD,0.7983631008272134,0.0,0.0,92163,52,-0.0217,0.02684 +P1_B3004,0.7794726567227461,0.0,0.0,92163,52,369.75601,447.83438 +P1_LIT01,0.7761347161675927,0.0,0.0,92163,52,356.09085,459.24484 +P4_ST_PT01,0.7676921073783155,0.0,0.0,92163,52,9914.0,10330.0 +P1_PCV02Z,0.7670214728253204,0.0,0.0,92163,52,11.76605,12.04071 +P1_PIT02,0.7347702941026726,0.0,0.0,92163,52,0.17105,2.34161 +P4_ST_PO,0.7212397099119537,0.019230769230769232,0.0,92163,52,233.66968,498.60754 +P1_B2016,0.6999296397102459,0.0,0.0,92163,52,0.9508,2.0523 +P4_ST_LD,0.6933532896148046,0.0,0.0,92163,52,230.55914,499.62018 +P4_LD,0.6897361614330463,0.0,0.0,92163,52,231.33685,498.58942 +P3_LIT01,0.6615471835435378,0.0,0.0,92163,52,5047.0,19680.0 +P1_PCV01D,0.6231695265662259,0.0,0.4807692307692308,92163,52,24.95222,100.0 +P1_B2004,0.617741226038482,0.0,0.0,92163,52,0.02978,0.10196 +P2_CO_rpm,0.6100514640031582,0.0,0.0,92163,52,53993.0,54183.0 +P4_ST_GOV,0.6084888062037244,0.0,0.0,92163,52,12665.0,26898.0 +P1_FCV02Z,0.5961538461538461,0.5961538461538461,0.0,92163,52,-1.89057,97.38312 +P1_B4002,0.5783991406529735,0.0,0.0,92163,52,31.41343,33.6555 +P1_FT01Z,0.5633543078775981,0.0,0.0,92163,52,0.0,1365.69287 +P1_PCV01Z,0.547708324465266,0.0,0.5,92163,52,25.57526,100.0 +P1_B3005,0.5359248538751159,0.0,0.0,92163,52,890.07843,1121.94116 +P1_B4005,0.5101396438918004,0.019230769230769232,0.0,92163,52,0.0,100.0 +P1_FT02,0.5049748814600219,0.0,0.0,92163,52,4.99723,2005.23364 +P3_FIT01,0.497898998346575,0.0,0.0,92163,52,-27.0,5421.0 +P2_24Vdc,0.4871763572733593,0.0,0.0,92163,52,28.01351,28.04294 +P4_HT_LD,0.48082202185258727,0.6346153846153846,0.0,92163,52,-0.00723,83.04398 +P1_B400B,0.4544694642184959,0.0,0.0,92163,52,25.02598,2855.56567 +P2_VXT03,0.45055916816276176,0.0,0.0,92163,52,-2.135,0.1491 +P2_VYT03,0.4479521650186668,0.0,0.0,92163,52,4.6083,7.2547 +P2_VXT02,0.44536394131133883,0.0,0.0,92163,52,-4.3925,-1.8818 +P4_HT_PO,0.42936573912941867,0.019230769230769232,0.0,92163,52,0.05423,83.04401 +P3_LCV01D,0.4154990030205681,0.25,0.0,92163,52,-288.0,17776.0 +P1_FT03,0.41513384730565156,0.0,0.0,92163,52,187.91197,331.15381 +P1_FT02Z,0.40720829900869615,0.0,0.0,92163,52,25.02598,2856.88574 +P2_VT01,0.36856126144398005,0.0,0.0,92163,52,11.76163,12.06125 +P1_TIT02,0.3579625646534276,0.0,0.0,92163,52,34.99451,40.4419 +P1_FCV03Z,0.35665363791075844,0.0,0.0,92163,52,46.20513,75.3189 +P1_LCV01Z,0.3512624789357317,0.0,0.0,92163,52,0.29907,28.52783 +P1_FCV03D,0.30470616858592514,0.0,0.0,92163,52,45.78336,74.1622 +P4_ST_TT01,0.30430700122441934,0.0,0.21153846153846154,92163,52,27539.0,27629.0 +P2_HILout,0.3041348563873872,0.0,0.0,92163,52,673.80371,768.76831 +P4_ST_FD,0.30162947086224323,0.0,0.0,92163,52,-0.05244,0.05035 +P1_B4022,0.2862201083531769,0.0,0.0,92163,52,34.21529,38.63682 +P1_TIT01,0.2807849220319517,0.0,0.0,92163,52,34.68933,36.94763 +P1_LCV01D,0.28024261363019864,0.0,0.0,92163,52,3.17127,28.23791 +P1_FT03Z,0.24018503170386246,0.0,0.0,92163,52,867.43927,1146.92163 +P1_PIT01,0.21846515245981413,0.0,0.0,92163,52,0.88211,2.38739 +P1_FT01,0.21452397466361856,0.0,0.0,92163,52,-9.88007,462.57019 +P2_VYT02,0.18998029411101902,0.0,0.0,92163,52,2.4459,5.1248 diff --git a/example/results/ks_summary.json b/example/results/ks_summary.json new file mode 100644 index 0000000..9fa6b90 --- /dev/null +++ b/example/results/ks_summary.json @@ -0,0 +1,17 @@ +{ + "generated_rows": 52, + "reference_rows_per_file": 50000, + "stride": 10, + "top_k_features": [ + "P2_MSD", + "P3_PIT01", + "P2_SIT02", + "P2_SIT01", + "P3_LCP01D", + "P4_HT_FD", + "P1_B3004", + "P1_LIT01", + "P4_ST_PT01", + "P1_PCV02Z" + ] +} \ No newline at end of file diff --git a/example/train.py b/example/train.py index 4a98c8d..3e27a29 100755 --- a/example/train.py +++ b/example/train.py @@ -173,6 +173,9 @@ def main(): std = stats["std"] transforms = stats.get("transform", {}) raw_std = stats.get("raw_std", std) + quantile_probs = stats.get("quantile_probs") + quantile_values = stats.get("quantile_values") + use_quantile = bool(config.get("use_quantile_transform", False)) vocab = load_json(config["vocab_path"])["vocab"] vocab_sizes = [len(vocab[c]) for c in disc_cols] @@ -253,6 +256,9 @@ def main(): max_batches=int(config["max_batches"]), return_file_id=False, transforms=transforms, + quantile_probs=quantile_probs, + quantile_values=quantile_values, + use_quantile=use_quantile, shuffle_buffer=int(config.get("shuffle_buffer", 0)), ) ): @@ -284,6 +290,9 @@ def main(): max_batches=int(config["max_batches"]), return_file_id=use_condition, transforms=transforms, + quantile_probs=quantile_probs, + quantile_values=quantile_values, + use_quantile=use_quantile, shuffle_buffer=int(config.get("shuffle_buffer", 0)), ) ): diff --git a/report.md b/report.md index 64e78a5..356cd3c 100644 --- a/report.md +++ b/report.md @@ -144,6 +144,7 @@ Defined in `example/data_utils.py` + `example/prepare_data.py`. Key steps: - Streaming mean/std/min/max + int-like detection - Optional **log1p transform** for heavy-tailed continuous columns +- Optional **quantile transform** (TabDDPM-style) for continuous columns - Discrete vocab + most frequent token - Windowed batching with **shuffle buffer** @@ -159,7 +160,8 @@ Export process: - Diffusion generates residuals - Output: `trend + residual` - De-normalize continuous values -- Clamp to observed min/max +- Inverse quantile transform (if enabled) +- Bound to observed min/max (clamp or sigmoid mapping) - Restore discrete tokens from vocab - Write to CSV