Add full quantile stats and post-hoc calibration
This commit is contained in:
@@ -14,3 +14,5 @@ Conventions:
|
|||||||
Tools:
|
Tools:
|
||||||
- `example/diagnose_ks.py` for per-feature KS + CDF plots.
|
- `example/diagnose_ks.py` for per-feature KS + CDF plots.
|
||||||
- `example/run_all_full.py` for one-command full pipeline + diagnostics.
|
- `example/run_all_full.py` for one-command full pipeline + diagnostics.
|
||||||
|
Notes:
|
||||||
|
- If `use_quantile_transform` is enabled, run `prepare_data.py` with `full_stats: true` to build quantile tables.
|
||||||
|
|||||||
@@ -62,3 +62,12 @@
|
|||||||
- **Files**:
|
- **Files**:
|
||||||
- `example/export_samples.py`
|
- `example/export_samples.py`
|
||||||
- `example/config.json`
|
- `example/config.json`
|
||||||
|
|
||||||
|
## 2026-01-27 — Post-hoc quantile calibration
|
||||||
|
- **Decision**: Add optional post-hoc quantile calibration to align generated 1D CDFs with real data.
|
||||||
|
- **Why**: KS remained high with distribution shifts even after boundary fixes.
|
||||||
|
- **Files**:
|
||||||
|
- `example/data_utils.py`
|
||||||
|
- `example/export_samples.py`
|
||||||
|
- `example/prepare_data.py`
|
||||||
|
- `example/config.json`
|
||||||
|
|||||||
@@ -44,21 +44,11 @@
|
|||||||
"cont_clamp_x0": 5.0,
|
"cont_clamp_x0": 5.0,
|
||||||
"use_quantile_transform": true,
|
"use_quantile_transform": true,
|
||||||
"quantile_bins": 1001,
|
"quantile_bins": 1001,
|
||||||
"cont_bound_mode": "soft_tanh",
|
"cont_bound_mode": "none",
|
||||||
"cont_bound_strength": 2.0,
|
"cont_bound_strength": 2.0,
|
||||||
"cont_post_scale": {
|
"cont_post_calibrate": true,
|
||||||
"P1_B4002": 0.8,
|
"cont_post_scale": {},
|
||||||
"P1_B400B": 0.8,
|
"full_stats": true,
|
||||||
"P1_FT02Z": 0.8,
|
|
||||||
"P1_PCV01D": 0.8,
|
|
||||||
"P1_PCV01Z": 0.8,
|
|
||||||
"P1_PCV02Z": 0.8,
|
|
||||||
"P2_24Vdc": 0.8,
|
|
||||||
"P2_MSD": 0.8,
|
|
||||||
"P3_LCP01D": 0.8,
|
|
||||||
"P4_ST_PT01": 0.8,
|
|
||||||
"P4_ST_TT01": 0.8
|
|
||||||
},
|
|
||||||
"shuffle_buffer": 256,
|
"shuffle_buffer": 256,
|
||||||
"use_temporal_stage1": true,
|
"use_temporal_stage1": true,
|
||||||
"temporal_hidden_dim": 256,
|
"temporal_hidden_dim": 256,
|
||||||
|
|||||||
@@ -44,21 +44,11 @@
|
|||||||
"cont_clamp_x0": 5.0,
|
"cont_clamp_x0": 5.0,
|
||||||
"use_quantile_transform": true,
|
"use_quantile_transform": true,
|
||||||
"quantile_bins": 1001,
|
"quantile_bins": 1001,
|
||||||
"cont_bound_mode": "soft_tanh",
|
"cont_bound_mode": "none",
|
||||||
"cont_bound_strength": 2.0,
|
"cont_bound_strength": 2.0,
|
||||||
"cont_post_scale": {
|
"cont_post_calibrate": true,
|
||||||
"P1_B4002": 0.8,
|
"cont_post_scale": {},
|
||||||
"P1_B400B": 0.8,
|
"full_stats": true,
|
||||||
"P1_FT02Z": 0.8,
|
|
||||||
"P1_PCV01D": 0.8,
|
|
||||||
"P1_PCV01Z": 0.8,
|
|
||||||
"P1_PCV02Z": 0.8,
|
|
||||||
"P2_24Vdc": 0.8,
|
|
||||||
"P2_MSD": 0.8,
|
|
||||||
"P3_LCP01D": 0.8,
|
|
||||||
"P4_ST_PT01": 0.8,
|
|
||||||
"P4_ST_TT01": 0.8
|
|
||||||
},
|
|
||||||
"shuffle_buffer": 1024,
|
"shuffle_buffer": 1024,
|
||||||
"use_temporal_stage1": false,
|
"use_temporal_stage1": false,
|
||||||
"sample_batch_size": 4,
|
"sample_batch_size": 4,
|
||||||
|
|||||||
@@ -44,21 +44,11 @@
|
|||||||
"cont_clamp_x0": 5.0,
|
"cont_clamp_x0": 5.0,
|
||||||
"use_quantile_transform": true,
|
"use_quantile_transform": true,
|
||||||
"quantile_bins": 1001,
|
"quantile_bins": 1001,
|
||||||
"cont_bound_mode": "soft_tanh",
|
"cont_bound_mode": "none",
|
||||||
"cont_bound_strength": 2.0,
|
"cont_bound_strength": 2.0,
|
||||||
"cont_post_scale": {
|
"cont_post_calibrate": true,
|
||||||
"P1_B4002": 0.8,
|
"cont_post_scale": {},
|
||||||
"P1_B400B": 0.8,
|
"full_stats": true,
|
||||||
"P1_FT02Z": 0.8,
|
|
||||||
"P1_PCV01D": 0.8,
|
|
||||||
"P1_PCV01Z": 0.8,
|
|
||||||
"P1_PCV02Z": 0.8,
|
|
||||||
"P2_24Vdc": 0.8,
|
|
||||||
"P2_MSD": 0.8,
|
|
||||||
"P3_LCP01D": 0.8,
|
|
||||||
"P4_ST_PT01": 0.8,
|
|
||||||
"P4_ST_TT01": 0.8
|
|
||||||
},
|
|
||||||
"shuffle_buffer": 1024,
|
"shuffle_buffer": 1024,
|
||||||
"use_temporal_stage1": true,
|
"use_temporal_stage1": true,
|
||||||
"temporal_hidden_dim": 512,
|
"temporal_hidden_dim": 512,
|
||||||
|
|||||||
@@ -153,12 +153,15 @@ def compute_cont_stats(
|
|||||||
mean = {c: 0.0 for c in cont_cols}
|
mean = {c: 0.0 for c in cont_cols}
|
||||||
m2 = {c: 0.0 for c in cont_cols}
|
m2 = {c: 0.0 for c in cont_cols}
|
||||||
quantile_values = {c: [] for c in cont_cols} if quantile_bins and quantile_bins > 1 else None
|
quantile_values = {c: [] for c in cont_cols} if quantile_bins and quantile_bins > 1 else None
|
||||||
|
raw_quantile_values = {c: [] for c in cont_cols} if quantile_bins and quantile_bins > 1 else None
|
||||||
for i, row in enumerate(iter_rows(path)):
|
for i, row in enumerate(iter_rows(path)):
|
||||||
for c in cont_cols:
|
for c in cont_cols:
|
||||||
raw_val = row[c]
|
raw_val = row[c]
|
||||||
if raw_val is None or raw_val == "":
|
if raw_val is None or raw_val == "":
|
||||||
continue
|
continue
|
||||||
x = float(raw_val)
|
x = float(raw_val)
|
||||||
|
if raw_quantile_values is not None:
|
||||||
|
raw_quantile_values[c].append(x)
|
||||||
if transforms.get(c) == "log1p":
|
if transforms.get(c) == "log1p":
|
||||||
if x < 0:
|
if x < 0:
|
||||||
x = 0.0
|
x = 0.0
|
||||||
@@ -184,22 +187,36 @@ def compute_cont_stats(
|
|||||||
|
|
||||||
quantile_probs = None
|
quantile_probs = None
|
||||||
quantile_table = None
|
quantile_table = None
|
||||||
|
raw_quantile_table = None
|
||||||
if quantile_values is not None:
|
if quantile_values is not None:
|
||||||
quantile_probs = [i / (quantile_bins - 1) for i in range(quantile_bins)]
|
quantile_probs = [i / (quantile_bins - 1) for i in range(quantile_bins)]
|
||||||
quantile_table = {}
|
quantile_table = {}
|
||||||
|
raw_quantile_table = {}
|
||||||
for c in cont_cols:
|
for c in cont_cols:
|
||||||
vals = quantile_values[c]
|
vals = quantile_values[c]
|
||||||
if not vals:
|
if not vals:
|
||||||
quantile_table[c] = [0.0 for _ in quantile_probs]
|
quantile_table[c] = [0.0 for _ in quantile_probs]
|
||||||
|
else:
|
||||||
|
vals.sort()
|
||||||
|
n = len(vals)
|
||||||
|
qvals = []
|
||||||
|
for p in quantile_probs:
|
||||||
|
idx = int(round(p * (n - 1)))
|
||||||
|
idx = max(0, min(n - 1, idx))
|
||||||
|
qvals.append(float(vals[idx]))
|
||||||
|
quantile_table[c] = qvals
|
||||||
|
raw_vals = raw_quantile_values[c] if raw_quantile_values is not None else []
|
||||||
|
if not raw_vals:
|
||||||
|
raw_quantile_table[c] = [0.0 for _ in quantile_probs]
|
||||||
continue
|
continue
|
||||||
vals.sort()
|
raw_vals.sort()
|
||||||
n = len(vals)
|
n = len(raw_vals)
|
||||||
qvals = []
|
rqvals = []
|
||||||
for p in quantile_probs:
|
for p in quantile_probs:
|
||||||
idx = int(round(p * (n - 1)))
|
idx = int(round(p * (n - 1)))
|
||||||
idx = max(0, min(n - 1, idx))
|
idx = max(0, min(n - 1, idx))
|
||||||
qvals.append(float(vals[idx]))
|
rqvals.append(float(raw_vals[idx]))
|
||||||
quantile_table[c] = qvals
|
raw_quantile_table[c] = rqvals
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"mean": mean,
|
"mean": mean,
|
||||||
@@ -216,6 +233,7 @@ def compute_cont_stats(
|
|||||||
"max_rows": max_rows,
|
"max_rows": max_rows,
|
||||||
"quantile_probs": quantile_probs,
|
"quantile_probs": quantile_probs,
|
||||||
"quantile_values": quantile_table,
|
"quantile_values": quantile_table,
|
||||||
|
"quantile_raw_values": raw_quantile_table,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -344,6 +362,35 @@ def inverse_quantile_transform(x, cont_cols, quantile_probs, quantile_values):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def quantile_calibrate_to_real(x, cont_cols, quantile_probs, real_quantile_values):
|
||||||
|
import torch
|
||||||
|
probs_t = torch.tensor(quantile_probs, dtype=x.dtype, device=x.device)
|
||||||
|
flat = x.reshape(-1, x.size(-1))
|
||||||
|
for i, c in enumerate(cont_cols):
|
||||||
|
v = flat[:, i]
|
||||||
|
gen_q = torch.quantile(v, probs_t)
|
||||||
|
idx = torch.bucketize(v, gen_q)
|
||||||
|
idx = torch.clamp(idx, 1, gen_q.numel() - 1)
|
||||||
|
x0 = gen_q[idx - 1]
|
||||||
|
x1 = gen_q[idx]
|
||||||
|
p0 = probs_t[idx - 1]
|
||||||
|
p1 = probs_t[idx]
|
||||||
|
denom = torch.where((x1 - x0) == 0, torch.ones_like(x1 - x0), (x1 - x0))
|
||||||
|
p = p0 + (v - x0) * (p1 - p0) / denom
|
||||||
|
|
||||||
|
real_q = torch.tensor(real_quantile_values[c], dtype=x.dtype, device=x.device)
|
||||||
|
idx2 = torch.bucketize(p, probs_t)
|
||||||
|
idx2 = torch.clamp(idx2, 1, probs_t.numel() - 1)
|
||||||
|
rp0 = probs_t[idx2 - 1]
|
||||||
|
rp1 = probs_t[idx2]
|
||||||
|
r0 = real_q[idx2 - 1]
|
||||||
|
r1 = real_q[idx2]
|
||||||
|
denom2 = torch.where((rp1 - rp0) == 0, torch.ones_like(rp1 - rp0), (rp1 - rp0))
|
||||||
|
v2 = r0 + (p - rp0) * (r1 - r0) / denom2
|
||||||
|
flat[:, i] = v2
|
||||||
|
return flat.reshape(x.shape)
|
||||||
|
|
||||||
|
|
||||||
def windowed_batches(
|
def windowed_batches(
|
||||||
path: Union[str, List[str]],
|
path: Union[str, List[str]],
|
||||||
cont_cols: List[str],
|
cont_cols: List[str],
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from typing import Dict, List
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from data_utils import load_split, inverse_quantile_transform
|
from data_utils import load_split, inverse_quantile_transform, quantile_calibrate_to_real
|
||||||
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
|
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
|
||||||
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
|
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
|
||||||
|
|
||||||
@@ -114,6 +114,7 @@ def main():
|
|||||||
transforms = stats.get("transform", {})
|
transforms = stats.get("transform", {})
|
||||||
quantile_probs = stats.get("quantile_probs")
|
quantile_probs = stats.get("quantile_probs")
|
||||||
quantile_values = stats.get("quantile_values")
|
quantile_values = stats.get("quantile_values")
|
||||||
|
quantile_raw_values = stats.get("quantile_raw_values")
|
||||||
|
|
||||||
vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
|
vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
|
||||||
vocab = vocab_json["vocab"]
|
vocab = vocab_json["vocab"]
|
||||||
@@ -146,6 +147,7 @@ def main():
|
|||||||
cont_bound_mode = str(cfg.get("cont_bound_mode", "clamp"))
|
cont_bound_mode = str(cfg.get("cont_bound_mode", "clamp"))
|
||||||
cont_bound_strength = float(cfg.get("cont_bound_strength", 1.0))
|
cont_bound_strength = float(cfg.get("cont_bound_strength", 1.0))
|
||||||
cont_post_scale = cfg.get("cont_post_scale", {}) if isinstance(cfg.get("cont_post_scale", {}), dict) else {}
|
cont_post_scale = cfg.get("cont_post_scale", {}) if isinstance(cfg.get("cont_post_scale", {}), dict) else {}
|
||||||
|
cont_post_calibrate = bool(cfg.get("cont_post_calibrate", False))
|
||||||
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
|
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
|
||||||
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
|
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
|
||||||
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
|
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
|
||||||
@@ -282,6 +284,8 @@ def main():
|
|||||||
for i, c in enumerate(cont_cols):
|
for i, c in enumerate(cont_cols):
|
||||||
if transforms.get(c) == "log1p":
|
if transforms.get(c) == "log1p":
|
||||||
x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
|
x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
|
||||||
|
if cont_post_calibrate and quantile_raw_values and quantile_probs:
|
||||||
|
x_cont = quantile_calibrate_to_real(x_cont, cont_cols, quantile_probs, quantile_raw_values)
|
||||||
# bound to observed min/max per feature
|
# bound to observed min/max per feature
|
||||||
if vmin and vmax:
|
if vmin and vmax:
|
||||||
for i, c in enumerate(cont_cols):
|
for i, c in enumerate(cont_cols):
|
||||||
@@ -291,6 +295,8 @@ def main():
|
|||||||
continue
|
continue
|
||||||
lo = float(lo)
|
lo = float(lo)
|
||||||
hi = float(hi)
|
hi = float(hi)
|
||||||
|
if cont_bound_mode == "none":
|
||||||
|
continue
|
||||||
if cont_bound_mode == "sigmoid":
|
if cont_bound_mode == "sigmoid":
|
||||||
x_cont[:, :, i] = lo + (hi - lo) * torch.sigmoid(x_cont[:, :, i])
|
x_cont[:, :, i] = lo + (hi - lo) * torch.sigmoid(x_cont[:, :, i])
|
||||||
elif cont_bound_mode == "soft_tanh":
|
elif cont_bound_mode == "soft_tanh":
|
||||||
|
|||||||
@@ -20,10 +20,15 @@ def main(max_rows: Optional[int] = None):
|
|||||||
config_path = BASE_DIR / "config.json"
|
config_path = BASE_DIR / "config.json"
|
||||||
use_quantile = False
|
use_quantile = False
|
||||||
quantile_bins = None
|
quantile_bins = None
|
||||||
|
full_stats = False
|
||||||
if config_path.exists():
|
if config_path.exists():
|
||||||
cfg = json.loads(config_path.read_text(encoding="utf-8"))
|
cfg = json.loads(config_path.read_text(encoding="utf-8"))
|
||||||
use_quantile = bool(cfg.get("use_quantile_transform", False))
|
use_quantile = bool(cfg.get("use_quantile_transform", False))
|
||||||
quantile_bins = int(cfg.get("quantile_bins", 0)) if use_quantile else None
|
quantile_bins = int(cfg.get("quantile_bins", 0)) if use_quantile else None
|
||||||
|
full_stats = bool(cfg.get("full_stats", False))
|
||||||
|
|
||||||
|
if full_stats:
|
||||||
|
max_rows = None
|
||||||
|
|
||||||
split = load_split(safe_path(SPLIT_PATH))
|
split = load_split(safe_path(SPLIT_PATH))
|
||||||
time_col = split.get("time_column", "time")
|
time_col = split.get("time_column", "time")
|
||||||
@@ -62,6 +67,7 @@ def main(max_rows: Optional[int] = None):
|
|||||||
"max_rows": cont_stats["max_rows"],
|
"max_rows": cont_stats["max_rows"],
|
||||||
"quantile_probs": cont_stats["quantile_probs"],
|
"quantile_probs": cont_stats["quantile_probs"],
|
||||||
"quantile_values": cont_stats["quantile_values"],
|
"quantile_values": cont_stats["quantile_values"],
|
||||||
|
"quantile_raw_values": cont_stats["quantile_raw_values"],
|
||||||
},
|
},
|
||||||
f,
|
f,
|
||||||
indent=2,
|
indent=2,
|
||||||
|
|||||||
@@ -145,6 +145,7 @@ Key steps:
|
|||||||
- Streaming mean/std/min/max + int-like detection
|
- Streaming mean/std/min/max + int-like detection
|
||||||
- Optional **log1p transform** for heavy-tailed continuous columns
|
- Optional **log1p transform** for heavy-tailed continuous columns
|
||||||
- Optional **quantile transform** (TabDDPM-style) for continuous columns (skips extra standardization)
|
- Optional **quantile transform** (TabDDPM-style) for continuous columns (skips extra standardization)
|
||||||
|
- Optional **post-hoc quantile calibration** to align 1D CDFs after sampling
|
||||||
- Discrete vocab + most frequent token
|
- Discrete vocab + most frequent token
|
||||||
- Windowed batching with **shuffle buffer**
|
- Windowed batching with **shuffle buffer**
|
||||||
|
|
||||||
@@ -161,7 +162,8 @@ Export process:
|
|||||||
- Output: `trend + residual`
|
- Output: `trend + residual`
|
||||||
- De-normalize continuous values
|
- De-normalize continuous values
|
||||||
- Inverse quantile transform (if enabled; no extra de-standardization)
|
- Inverse quantile transform (if enabled; no extra de-standardization)
|
||||||
- Bound to observed min/max (clamp or sigmoid mapping)
|
- Optional post-hoc quantile calibration (if enabled)
|
||||||
|
- Bound to observed min/max (clamp / sigmoid / soft_tanh / none)
|
||||||
- Restore discrete tokens from vocab
|
- Restore discrete tokens from vocab
|
||||||
- Write to CSV
|
- Write to CSV
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user