diff --git a/docs/README.md b/docs/README.md index 15d8cd2..fa63456 100644 --- a/docs/README.md +++ b/docs/README.md @@ -14,3 +14,5 @@ Conventions: Tools: - `example/diagnose_ks.py` for per-feature KS + CDF plots. - `example/run_all_full.py` for one-command full pipeline + diagnostics. +Notes: +- If `use_quantile_transform` is enabled, run `prepare_data.py` with `full_stats: true` to build quantile tables. diff --git a/docs/decisions.md b/docs/decisions.md index cd3930b..fbdd004 100644 --- a/docs/decisions.md +++ b/docs/decisions.md @@ -62,3 +62,12 @@ - **Files**: - `example/export_samples.py` - `example/config.json` + +## 2026-01-27 — Post-hoc quantile calibration +- **Decision**: Add optional post-hoc quantile calibration to align generated 1D CDFs with real data. +- **Why**: KS remained high with distribution shifts even after boundary fixes. +- **Files**: + - `example/data_utils.py` + - `example/export_samples.py` + - `example/prepare_data.py` + - `example/config.json` diff --git a/example/config.json b/example/config.json index 30d31b5..9258792 100644 --- a/example/config.json +++ b/example/config.json @@ -44,21 +44,11 @@ "cont_clamp_x0": 5.0, "use_quantile_transform": true, "quantile_bins": 1001, - "cont_bound_mode": "soft_tanh", + "cont_bound_mode": "none", "cont_bound_strength": 2.0, - "cont_post_scale": { - "P1_B4002": 0.8, - "P1_B400B": 0.8, - "P1_FT02Z": 0.8, - "P1_PCV01D": 0.8, - "P1_PCV01Z": 0.8, - "P1_PCV02Z": 0.8, - "P2_24Vdc": 0.8, - "P2_MSD": 0.8, - "P3_LCP01D": 0.8, - "P4_ST_PT01": 0.8, - "P4_ST_TT01": 0.8 - }, + "cont_post_calibrate": true, + "cont_post_scale": {}, + "full_stats": true, "shuffle_buffer": 256, "use_temporal_stage1": true, "temporal_hidden_dim": 256, diff --git a/example/config_no_temporal.json b/example/config_no_temporal.json index 1e6a9a8..b8ea915 100644 --- a/example/config_no_temporal.json +++ b/example/config_no_temporal.json @@ -44,21 +44,11 @@ "cont_clamp_x0": 5.0, "use_quantile_transform": true, "quantile_bins": 1001, - "cont_bound_mode": "soft_tanh", + "cont_bound_mode": "none", "cont_bound_strength": 2.0, - "cont_post_scale": { - "P1_B4002": 0.8, - "P1_B400B": 0.8, - "P1_FT02Z": 0.8, - "P1_PCV01D": 0.8, - "P1_PCV01Z": 0.8, - "P1_PCV02Z": 0.8, - "P2_24Vdc": 0.8, - "P2_MSD": 0.8, - "P3_LCP01D": 0.8, - "P4_ST_PT01": 0.8, - "P4_ST_TT01": 0.8 - }, + "cont_post_calibrate": true, + "cont_post_scale": {}, + "full_stats": true, "shuffle_buffer": 1024, "use_temporal_stage1": false, "sample_batch_size": 4, diff --git a/example/config_temporal_strong.json b/example/config_temporal_strong.json index 7bea30e..b5af569 100644 --- a/example/config_temporal_strong.json +++ b/example/config_temporal_strong.json @@ -44,21 +44,11 @@ "cont_clamp_x0": 5.0, "use_quantile_transform": true, "quantile_bins": 1001, - "cont_bound_mode": "soft_tanh", + "cont_bound_mode": "none", "cont_bound_strength": 2.0, - "cont_post_scale": { - "P1_B4002": 0.8, - "P1_B400B": 0.8, - "P1_FT02Z": 0.8, - "P1_PCV01D": 0.8, - "P1_PCV01Z": 0.8, - "P1_PCV02Z": 0.8, - "P2_24Vdc": 0.8, - "P2_MSD": 0.8, - "P3_LCP01D": 0.8, - "P4_ST_PT01": 0.8, - "P4_ST_TT01": 0.8 - }, + "cont_post_calibrate": true, + "cont_post_scale": {}, + "full_stats": true, "shuffle_buffer": 1024, "use_temporal_stage1": true, "temporal_hidden_dim": 512, diff --git a/example/data_utils.py b/example/data_utils.py index 29c730a..4a01a5e 100755 --- a/example/data_utils.py +++ b/example/data_utils.py @@ -153,12 +153,15 @@ def compute_cont_stats( mean = {c: 0.0 for c in cont_cols} m2 = {c: 0.0 for c in cont_cols} quantile_values = {c: [] for c in cont_cols} if quantile_bins and quantile_bins > 1 else None + raw_quantile_values = {c: [] for c in cont_cols} if quantile_bins and quantile_bins > 1 else None for i, row in enumerate(iter_rows(path)): for c in cont_cols: raw_val = row[c] if raw_val is None or raw_val == "": continue x = float(raw_val) + if raw_quantile_values is not None: + raw_quantile_values[c].append(x) if transforms.get(c) == "log1p": if x < 0: x = 0.0 @@ -184,22 +187,36 @@ def compute_cont_stats( quantile_probs = None quantile_table = None + raw_quantile_table = None if quantile_values is not None: quantile_probs = [i / (quantile_bins - 1) for i in range(quantile_bins)] quantile_table = {} + raw_quantile_table = {} for c in cont_cols: vals = quantile_values[c] if not vals: quantile_table[c] = [0.0 for _ in quantile_probs] + else: + vals.sort() + n = len(vals) + qvals = [] + for p in quantile_probs: + idx = int(round(p * (n - 1))) + idx = max(0, min(n - 1, idx)) + qvals.append(float(vals[idx])) + quantile_table[c] = qvals + raw_vals = raw_quantile_values[c] if raw_quantile_values is not None else [] + if not raw_vals: + raw_quantile_table[c] = [0.0 for _ in quantile_probs] continue - vals.sort() - n = len(vals) - qvals = [] + raw_vals.sort() + n = len(raw_vals) + rqvals = [] for p in quantile_probs: idx = int(round(p * (n - 1))) idx = max(0, min(n - 1, idx)) - qvals.append(float(vals[idx])) - quantile_table[c] = qvals + rqvals.append(float(raw_vals[idx])) + raw_quantile_table[c] = rqvals return { "mean": mean, @@ -216,6 +233,7 @@ def compute_cont_stats( "max_rows": max_rows, "quantile_probs": quantile_probs, "quantile_values": quantile_table, + "quantile_raw_values": raw_quantile_table, } @@ -344,6 +362,35 @@ def inverse_quantile_transform(x, cont_cols, quantile_probs, quantile_values): return x +def quantile_calibrate_to_real(x, cont_cols, quantile_probs, real_quantile_values): + import torch + probs_t = torch.tensor(quantile_probs, dtype=x.dtype, device=x.device) + flat = x.reshape(-1, x.size(-1)) + for i, c in enumerate(cont_cols): + v = flat[:, i] + gen_q = torch.quantile(v, probs_t) + idx = torch.bucketize(v, gen_q) + idx = torch.clamp(idx, 1, gen_q.numel() - 1) + x0 = gen_q[idx - 1] + x1 = gen_q[idx] + p0 = probs_t[idx - 1] + p1 = probs_t[idx] + denom = torch.where((x1 - x0) == 0, torch.ones_like(x1 - x0), (x1 - x0)) + p = p0 + (v - x0) * (p1 - p0) / denom + + real_q = torch.tensor(real_quantile_values[c], dtype=x.dtype, device=x.device) + idx2 = torch.bucketize(p, probs_t) + idx2 = torch.clamp(idx2, 1, probs_t.numel() - 1) + rp0 = probs_t[idx2 - 1] + rp1 = probs_t[idx2] + r0 = real_q[idx2 - 1] + r1 = real_q[idx2] + denom2 = torch.where((rp1 - rp0) == 0, torch.ones_like(rp1 - rp0), (rp1 - rp0)) + v2 = r0 + (p - rp0) * (r1 - r0) / denom2 + flat[:, i] = v2 + return flat.reshape(x.shape) + + def windowed_batches( path: Union[str, List[str]], cont_cols: List[str], diff --git a/example/export_samples.py b/example/export_samples.py index ab59fc3..04b1a89 100644 --- a/example/export_samples.py +++ b/example/export_samples.py @@ -12,7 +12,7 @@ from typing import Dict, List import torch import torch.nn.functional as F -from data_utils import load_split, inverse_quantile_transform +from data_utils import load_split, inverse_quantile_transform, quantile_calibrate_to_real from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path @@ -114,6 +114,7 @@ def main(): transforms = stats.get("transform", {}) quantile_probs = stats.get("quantile_probs") quantile_values = stats.get("quantile_values") + quantile_raw_values = stats.get("quantile_raw_values") vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8")) vocab = vocab_json["vocab"] @@ -146,6 +147,7 @@ def main(): cont_bound_mode = str(cfg.get("cont_bound_mode", "clamp")) cont_bound_strength = float(cfg.get("cont_bound_strength", 1.0)) cont_post_scale = cfg.get("cont_post_scale", {}) if isinstance(cfg.get("cont_post_scale", {}), dict) else {} + cont_post_calibrate = bool(cfg.get("cont_post_calibrate", False)) use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False)) temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256)) temporal_num_layers = int(cfg.get("temporal_num_layers", 1)) @@ -282,6 +284,8 @@ def main(): for i, c in enumerate(cont_cols): if transforms.get(c) == "log1p": x_cont[:, :, i] = torch.expm1(x_cont[:, :, i]) + if cont_post_calibrate and quantile_raw_values and quantile_probs: + x_cont = quantile_calibrate_to_real(x_cont, cont_cols, quantile_probs, quantile_raw_values) # bound to observed min/max per feature if vmin and vmax: for i, c in enumerate(cont_cols): @@ -291,6 +295,8 @@ def main(): continue lo = float(lo) hi = float(hi) + if cont_bound_mode == "none": + continue if cont_bound_mode == "sigmoid": x_cont[:, :, i] = lo + (hi - lo) * torch.sigmoid(x_cont[:, :, i]) elif cont_bound_mode == "soft_tanh": diff --git a/example/prepare_data.py b/example/prepare_data.py index 3c6b5fc..e4e422c 100755 --- a/example/prepare_data.py +++ b/example/prepare_data.py @@ -20,10 +20,15 @@ def main(max_rows: Optional[int] = None): config_path = BASE_DIR / "config.json" use_quantile = False quantile_bins = None + full_stats = False if config_path.exists(): cfg = json.loads(config_path.read_text(encoding="utf-8")) use_quantile = bool(cfg.get("use_quantile_transform", False)) quantile_bins = int(cfg.get("quantile_bins", 0)) if use_quantile else None + full_stats = bool(cfg.get("full_stats", False)) + + if full_stats: + max_rows = None split = load_split(safe_path(SPLIT_PATH)) time_col = split.get("time_column", "time") @@ -62,6 +67,7 @@ def main(max_rows: Optional[int] = None): "max_rows": cont_stats["max_rows"], "quantile_probs": cont_stats["quantile_probs"], "quantile_values": cont_stats["quantile_values"], + "quantile_raw_values": cont_stats["quantile_raw_values"], }, f, indent=2, diff --git a/report.md b/report.md index dde9b5e..e7401d7 100644 --- a/report.md +++ b/report.md @@ -145,6 +145,7 @@ Key steps: - Streaming mean/std/min/max + int-like detection - Optional **log1p transform** for heavy-tailed continuous columns - Optional **quantile transform** (TabDDPM-style) for continuous columns (skips extra standardization) +- Optional **post-hoc quantile calibration** to align 1D CDFs after sampling - Discrete vocab + most frequent token - Windowed batching with **shuffle buffer** @@ -161,7 +162,8 @@ Export process: - Output: `trend + residual` - De-normalize continuous values - Inverse quantile transform (if enabled; no extra de-standardization) -- Bound to observed min/max (clamp or sigmoid mapping) +- Optional post-hoc quantile calibration (if enabled) +- Bound to observed min/max (clamp / sigmoid / soft_tanh / none) - Restore discrete tokens from vocab - Write to CSV