Add full quantile stats and post-hoc calibration

This commit is contained in:
2026-01-28 00:52:42 +08:00
parent 6d5c5fffb1
commit c68a6e3c97
9 changed files with 91 additions and 49 deletions

View File

@@ -12,7 +12,7 @@ from typing import Dict, List
import torch
import torch.nn.functional as F
from data_utils import load_split, inverse_quantile_transform
from data_utils import load_split, inverse_quantile_transform, quantile_calibrate_to_real
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
@@ -114,6 +114,7 @@ def main():
transforms = stats.get("transform", {})
quantile_probs = stats.get("quantile_probs")
quantile_values = stats.get("quantile_values")
quantile_raw_values = stats.get("quantile_raw_values")
vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
vocab = vocab_json["vocab"]
@@ -146,6 +147,7 @@ def main():
cont_bound_mode = str(cfg.get("cont_bound_mode", "clamp"))
cont_bound_strength = float(cfg.get("cont_bound_strength", 1.0))
cont_post_scale = cfg.get("cont_post_scale", {}) if isinstance(cfg.get("cont_post_scale", {}), dict) else {}
cont_post_calibrate = bool(cfg.get("cont_post_calibrate", False))
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
@@ -282,6 +284,8 @@ def main():
for i, c in enumerate(cont_cols):
if transforms.get(c) == "log1p":
x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
if cont_post_calibrate and quantile_raw_values and quantile_probs:
x_cont = quantile_calibrate_to_real(x_cont, cont_cols, quantile_probs, quantile_raw_values)
# bound to observed min/max per feature
if vmin and vmax:
for i, c in enumerate(cont_cols):
@@ -291,6 +295,8 @@ def main():
continue
lo = float(lo)
hi = float(hi)
if cont_bound_mode == "none":
continue
if cont_bound_mode == "sigmoid":
x_cont[:, :, i] = lo + (hi - lo) * torch.sigmoid(x_cont[:, :, i])
elif cont_bound_mode == "soft_tanh":