This commit is contained in:
2026-01-27 18:39:24 +08:00
parent c46c25d607
commit a24c60c506
22 changed files with 357 additions and 8 deletions

View File

@@ -12,7 +12,7 @@ from typing import Dict, List
import torch
import torch.nn.functional as F
from data_utils import load_split
from data_utils import load_split, inverse_quantile_transform
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
@@ -112,6 +112,8 @@ def main():
int_like = stats.get("int_like", {})
max_decimals = stats.get("max_decimals", {})
transforms = stats.get("transform", {})
quantile_probs = stats.get("quantile_probs")
quantile_values = stats.get("quantile_values")
vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
vocab = vocab_json["vocab"]
@@ -140,6 +142,8 @@ def main():
raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob)
cont_target = str(cfg.get("cont_target", "eps"))
cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
use_quantile = bool(cfg.get("use_quantile_transform", False))
cont_bound_mode = str(cfg.get("cont_bound_mode", "clamp"))
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
@@ -270,15 +274,21 @@ def main():
mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
x_cont = x_cont * std_vec + mean_vec
if use_quantile:
x_cont = inverse_quantile_transform(x_cont, cont_cols, quantile_probs, quantile_values)
for i, c in enumerate(cont_cols):
if transforms.get(c) == "log1p":
x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
# clamp to observed min/max per feature
# bound to observed min/max per feature
if vmin and vmax:
for i, c in enumerate(cont_cols):
lo = vmin.get(c, None)
hi = vmax.get(c, None)
if lo is not None and hi is not None:
if lo is None or hi is None:
continue
if cont_bound_mode == "sigmoid":
x_cont[:, :, i] = float(lo) + (float(hi) - float(lo)) * torch.sigmoid(x_cont[:, :, i])
else:
x_cont[:, :, i] = torch.clamp(x_cont[:, :, i], float(lo), float(hi))
header = read_header(data_path)