Optimized 6 classes; KS is now down to 0.28. Dubbed version 3.0.

2026-01-28 20:10:42 +08:00
parent 59697c0640
commit 39eede92f6
28 changed files with 3317 additions and 225 deletions
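The KS figure in the commit message presumably refers to an average two-sample Kolmogorov-Smirnov statistic between real and generated feature distributions. A minimal sketch of how such a score can be computed (the function name and data frames are illustrative, not from this repo):

    # Hypothetical KS check; the frames and column list are illustrative only.
    import pandas as pd
    from scipy.stats import ks_2samp

    def mean_ks(real_df: pd.DataFrame, synth_df: pd.DataFrame, cols: list[str]) -> float:
        """Average two-sample KS statistic over the given continuous columns."""
        return sum(ks_2samp(real_df[c], synth_df[c]).statistic for c in cols) / len(cols)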


@@ -148,6 +148,11 @@ def main():
     cont_bound_strength = float(cfg.get("cont_bound_strength", 1.0))
     cont_post_scale = cfg.get("cont_post_scale", {}) if isinstance(cfg.get("cont_post_scale", {}), dict) else {}
     cont_post_calibrate = bool(cfg.get("cont_post_calibrate", False))
+    type1_cols = cfg.get("type1_features", []) or []
+    type5_cols = cfg.get("type5_features", []) or []
+    type1_cols = [c for c in type1_cols if c in cont_cols]
+    type5_cols = [c for c in type5_cols if c in cont_cols]
+    model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
     use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
     temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
     temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
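Note on this hunk: continuous features are now split three ways. type1_features are replayed from a reference trace and fed in as conditioning, type5_features are mirrored from a base column after sampling, and only the remaining model_cont_cols are modeled by the diffusion network. A hypothetical config fragment exercising the new keys (feature names are made up):

    # Illustrative config only; these feature names do not come from this repo.
    cfg = {
        "type1_features": ["progStep", "phaseId"],  # replayed from a reference file
        "type5_features": ["currentZ"],             # copied from base column "current"
    }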
@@ -159,7 +164,7 @@ def main():
     transformer_dropout = float(cfg.get("transformer_dropout", 0.1))
     model = HybridDiffusionModel(
-        cont_dim=len(cont_cols),
+        cont_dim=len(model_cont_cols),
         disc_vocab_sizes=vocab_sizes,
         time_dim=int(cfg.get("model_time_dim", 64)),
         hidden_dim=int(cfg.get("model_hidden_dim", 256)),
@@ -173,6 +178,7 @@ def main():
         transformer_nhead=transformer_nhead,
         transformer_ff_dim=transformer_ff_dim,
         transformer_dropout=transformer_dropout,
+        cond_cont_dim=len(type1_cols),
         cond_vocab_size=cond_vocab_size if use_condition else 0,
         cond_dim=int(cfg.get("cond_dim", 32)),
         use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
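The new cond_cont_dim argument implies the model now ingests a per-timestep continuous condition alongside the discrete condition id. This diff does not show HybridDiffusionModel's internals; a generic sketch of the usual pattern (projection plus feature-wise concatenation), purely as an assumption:

    import torch
    import torch.nn as nn

    # Assumed conditioning pattern, NOT HybridDiffusionModel's actual code.
    class ContCondProj(nn.Module):
        def __init__(self, cond_cont_dim: int, cond_dim: int):
            super().__init__()
            self.proj = nn.Linear(cond_cont_dim, cond_dim)

        def forward(self, h: torch.Tensor, cond_cont: torch.Tensor) -> torch.Tensor:
            # h: (B, T, H) hidden states; cond_cont: (B, T, cond_cont_dim)
            return torch.cat([h, self.proj(cond_cont)], dim=-1)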
@@ -188,7 +194,7 @@ def main():
     temporal_model = None
     if use_temporal_stage1:
         temporal_model = TemporalGRUGenerator(
-            input_dim=len(cont_cols),
+            input_dim=len(model_cont_cols),
             hidden_dim=temporal_hidden_dim,
             num_layers=temporal_num_layers,
             dropout=temporal_dropout,
@@ -203,7 +209,7 @@ def main():
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, dim=0)
-    x_cont = torch.randn(args.batch_size, args.seq_len, len(cont_cols), device=device)
+    x_cont = torch.randn(args.batch_size, args.seq_len, len(model_cont_cols), device=device)
     x_disc = torch.full(
         (args.batch_size, args.seq_len, len(disc_cols)),
         0,
@@ -225,13 +231,39 @@ def main():
         cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
         cond = cond_id
+    # type1 program conditioning (library replay)
+    cond_cont = None
+    if type1_cols:
+        ref_glob = cfg.get("data_glob") or args.data_glob
+        if ref_glob:
+            ref_glob = str(resolve_path(Path(args.config).parent, ref_glob)) if args.config else ref_glob
+            base = Path(ref_glob).parent
+            pat = Path(ref_glob).name
+            refs = sorted(base.glob(pat))
+            if refs:
+                ref_path = refs[0]
+                ref_rows = []
+                with gzip.open(ref_path, "rt", newline="") as fh:
+                    reader = csv.DictReader(fh)
+                    for row in reader:
+                        ref_rows.append(row)
+                if len(ref_rows) >= args.seq_len:
+                    seq = ref_rows[: args.seq_len]
+                    cond_cont = torch.zeros(args.batch_size, args.seq_len, len(type1_cols), device=device)
+                    for t, row in enumerate(seq):
+                        for i, c in enumerate(type1_cols):
+                            cond_cont[:, t, i] = float(row[c])
+                    mean_vec = torch.tensor([mean[c] for c in type1_cols], dtype=cond_cont.dtype, device=device)
+                    std_vec = torch.tensor([std[c] for c in type1_cols], dtype=cond_cont.dtype, device=device)
+                    cond_cont = (cond_cont - mean_vec) / std_vec
     trend = None
     if temporal_model is not None:
         trend = temporal_model.generate(args.batch_size, args.seq_len, device)
     for t in reversed(range(args.timesteps)):
         t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
-        eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
+        eps_pred, logits = model(x_cont, x_disc, t_batch, cond, cond_cont=cond_cont)
         a_t = alphas[t]
         a_bar_t = alphas_cumprod[t]
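The added block replays the first seq_len rows of the first file matching data_glob and z-normalizes them with the training-time mean/std, so the network sees type1 conditions in the same space it was trained in. The same step as a standalone helper (a sketch; mean and std are the stats dicts already loaded in main()):

    import torch

    # Sketch of the replay-and-normalize step; rows are csv.DictReader dicts.
    def load_type1_condition(rows, type1_cols, mean, std, device):
        vals = torch.tensor(
            [[float(r[c]) for c in type1_cols] for r in rows], device=device
        )
        mean_vec = torch.tensor([mean[c] for c in type1_cols], device=device)
        std_vec = torch.tensor([std[c] for c in type1_cols], device=device)
        return (vals - mean_vec) / std_vec  # shape (T, len(type1_cols))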
@@ -276,19 +308,21 @@ def main():
     x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
     if use_quantile:
-        x_cont = inverse_quantile_transform(x_cont, cont_cols, quantile_probs, quantile_values)
+        q_vals = {c: quantile_values[c] for c in model_cont_cols}
+        x_cont = inverse_quantile_transform(x_cont, model_cont_cols, quantile_probs, q_vals)
     else:
-        mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
-        std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
+        mean_vec = torch.tensor([mean[c] for c in model_cont_cols], dtype=x_cont.dtype)
+        std_vec = torch.tensor([std[c] for c in model_cont_cols], dtype=x_cont.dtype)
         x_cont = x_cont * std_vec + mean_vec
-        for i, c in enumerate(cont_cols):
+        for i, c in enumerate(model_cont_cols):
             if transforms.get(c) == "log1p":
                 x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
     if cont_post_calibrate and quantile_raw_values and quantile_probs:
-        x_cont = quantile_calibrate_to_real(x_cont, cont_cols, quantile_probs, quantile_raw_values)
+        q_raw = {c: quantile_raw_values[c] for c in model_cont_cols}
+        x_cont = quantile_calibrate_to_real(x_cont, model_cont_cols, quantile_probs, q_raw)
     # bound to observed min/max per feature
     if vmin and vmax:
-        for i, c in enumerate(cont_cols):
+        for i, c in enumerate(model_cont_cols):
             lo = vmin.get(c, None)
             hi = vmax.get(c, None)
             if lo is None or hi is None:
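In the non-quantile branch the inverse transform is z-denormalization followed by expm1 for columns that were log1p-transformed at training time. A one-column worked sketch with made-up stats:

    import torch

    z = torch.tensor([0.5, -1.2])   # sampled values in normalized space
    mu, sigma = 3.0, 0.8            # illustrative training-time mean/std
    x = z * sigma + mu              # undo z-normalization -> log1p space
    x = torch.expm1(x)              # undo log1p -> raw scale (exp(x) - 1)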
@@ -310,7 +344,7 @@ def main():
     # optional post-scaling for problematic features
     if cont_post_scale:
-        for i, c in enumerate(cont_cols):
+        for i, c in enumerate(model_cont_cols):
             if c in cont_post_scale:
                 try:
                     scale = float(cont_post_scale[c])
@@ -318,6 +352,26 @@ def main():
                     scale = 1.0
                 x_cont[:, :, i] = x_cont[:, :, i] * scale
+    # assemble full continuous output
+    full_cont = torch.zeros(args.batch_size, args.seq_len, len(cont_cols), dtype=x_cont.dtype)
+    for i, c in enumerate(model_cont_cols):
+        full_idx = cont_cols.index(c)
+        full_cont[:, :, full_idx] = x_cont[:, :, i]
+    if cond_cont is not None and type1_cols:
+        mean_vec = torch.tensor([mean[c] for c in type1_cols], dtype=cond_cont.dtype)
+        std_vec = torch.tensor([std[c] for c in type1_cols], dtype=cond_cont.dtype)
+        cond_denorm = cond_cont.cpu() * std_vec + mean_vec
+        for i, c in enumerate(type1_cols):
+            full_idx = cont_cols.index(c)
+            full_cont[:, :, full_idx] = cond_denorm[:, :, i]
+    for c in type5_cols:
+        if c.endswith("Z"):
+            base = c[:-1]
+            if base in cont_cols:
+                bidx = cont_cols.index(base)
+                cidx = cont_cols.index(c)
+                full_cont[:, :, cidx] = full_cont[:, :, bidx]
     header = read_header(data_path)
     out_cols = [c for c in header if c != time_col or args.include_time]
     if args.include_condition and use_condition:
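The assembly step scatters the modeled columns back into a tensor covering all of cont_cols, overwrites type1 columns with the denormalized replayed conditions, and forces each type5 column named "<base>Z" to equal its base column. The core of that logic as a compact sketch (argument names are illustrative, and the type1 overwrite is omitted):

    import torch

    # Sketch of the re-assembly above, minus the type1 overwrite.
    def assemble(x_cont, cont_cols, model_cont_cols, type5_cols):
        B, T, _ = x_cont.shape
        full = torch.zeros(B, T, len(cont_cols), dtype=x_cont.dtype)
        for i, c in enumerate(model_cont_cols):
            full[:, :, cont_cols.index(c)] = x_cont[:, :, i]
        for c in type5_cols:
            if c.endswith("Z") and c[:-1] in cont_cols:
                full[:, :, cont_cols.index(c)] = full[:, :, cont_cols.index(c[:-1])]
        return full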
@@ -337,7 +391,7 @@ def main():
             if args.include_time and time_col in header:
                 row[time_col] = str(row_index)
             for i, c in enumerate(cont_cols):
-                val = float(x_cont[b, t, i])
+                val = float(full_cont[b, t, i])
                 if int_like.get(c, False):
                     row[c] = str(int(round(val)))
                 else: