Optimized 6 classes; KS is now down to 0.28, henceforth known as version 3.0
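
For context: "KS" here presumably refers to the two-sample Kolmogorov-Smirnov statistic between real and generated feature distributions (lower is better). A minimal sketch of how such a score is usually computed, with invented names (max_feature_ks, the two DataFrames) and scipy's ks_2samp:

import pandas as pd
from scipy.stats import ks_2samp

def max_feature_ks(real: pd.DataFrame, synth: pd.DataFrame, cols: list[str]) -> float:
    """Worst (largest) per-feature KS statistic; the 0.28 above is this kind of number."""
    return max(ks_2samp(real[c].to_numpy(), synth[c].to_numpy()).statistic for c in cols)
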
@@ -148,6 +148,11 @@ def main():
     cont_bound_strength = float(cfg.get("cont_bound_strength", 1.0))
     cont_post_scale = cfg.get("cont_post_scale", {}) if isinstance(cfg.get("cont_post_scale", {}), dict) else {}
     cont_post_calibrate = bool(cfg.get("cont_post_calibrate", False))
+    type1_cols = cfg.get("type1_features", []) or []
+    type5_cols = cfg.get("type5_features", []) or []
+    type1_cols = [c for c in type1_cols if c in cont_cols]
+    type5_cols = [c for c in type5_cols if c in cont_cols]
+    model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
     use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
     temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
     temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
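
The hunk above splits cont_cols three ways: type1 columns are replayed from a reference trace, type5 columns are mirrored from a sibling column at the end, and only the remainder is modeled by the diffusion network. A toy illustration with invented column names:

cont_cols = ["cpu", "cpuZ", "mem", "io"]
type1_cols = [c for c in ["cpu"] if c in cont_cols]    # replayed from a reference file
type5_cols = [c for c in ["cpuZ"] if c in cont_cols]   # mirrored from its base column later
model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
assert model_cont_cols == ["mem", "io"]                # only these pass through the model
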
@@ -159,7 +164,7 @@ def main():
     transformer_dropout = float(cfg.get("transformer_dropout", 0.1))
 
     model = HybridDiffusionModel(
-        cont_dim=len(cont_cols),
+        cont_dim=len(model_cont_cols),
         disc_vocab_sizes=vocab_sizes,
         time_dim=int(cfg.get("model_time_dim", 64)),
         hidden_dim=int(cfg.get("model_hidden_dim", 256)),
@@ -173,6 +178,7 @@ def main():
         transformer_nhead=transformer_nhead,
         transformer_ff_dim=transformer_ff_dim,
         transformer_dropout=transformer_dropout,
+        cond_cont_dim=len(type1_cols),
         cond_vocab_size=cond_vocab_size if use_condition else 0,
         cond_dim=int(cfg.get("cond_dim", 32)),
         use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
@@ -188,7 +194,7 @@ def main():
     temporal_model = None
     if use_temporal_stage1:
         temporal_model = TemporalGRUGenerator(
-            input_dim=len(cont_cols),
+            input_dim=len(model_cont_cols),
             hidden_dim=temporal_hidden_dim,
             num_layers=temporal_num_layers,
             dropout=temporal_dropout,
@@ -203,7 +209,7 @@ def main():
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, dim=0)
 
-    x_cont = torch.randn(args.batch_size, args.seq_len, len(cont_cols), device=device)
+    x_cont = torch.randn(args.batch_size, args.seq_len, len(model_cont_cols), device=device)
     x_disc = torch.full(
         (args.batch_size, args.seq_len, len(disc_cols)),
         0,
@@ -225,13 +231,39 @@ def main():
         cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
         cond = cond_id
 
+    # type1 program conditioning (library replay)
+    cond_cont = None
+    if type1_cols:
+        ref_glob = cfg.get("data_glob") or args.data_glob
+        if ref_glob:
+            ref_glob = str(resolve_path(Path(args.config).parent, ref_glob)) if args.config else ref_glob
+            base = Path(ref_glob).parent
+            pat = Path(ref_glob).name
+            refs = sorted(base.glob(pat))
+            if refs:
+                ref_path = refs[0]
+                ref_rows = []
+                with gzip.open(ref_path, "rt", newline="") as fh:
+                    reader = csv.DictReader(fh)
+                    for row in reader:
+                        ref_rows.append(row)
+                if len(ref_rows) >= args.seq_len:
+                    seq = ref_rows[: args.seq_len]
+                    cond_cont = torch.zeros(args.batch_size, args.seq_len, len(type1_cols), device=device)
+                    for t, row in enumerate(seq):
+                        for i, c in enumerate(type1_cols):
+                            cond_cont[:, t, i] = float(row[c])
+                    mean_vec = torch.tensor([mean[c] for c in type1_cols], dtype=cond_cont.dtype, device=device)
+                    std_vec = torch.tensor([std[c] for c in type1_cols], dtype=cond_cont.dtype, device=device)
+                    cond_cont = (cond_cont - mean_vec) / std_vec
+
     trend = None
     if temporal_model is not None:
         trend = temporal_model.generate(args.batch_size, args.seq_len, device)
 
     for t in reversed(range(args.timesteps)):
         t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
-        eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
+        eps_pred, logits = model(x_cont, x_disc, t_batch, cond, cond_cont=cond_cont)
 
         a_t = alphas[t]
         a_bar_t = alphas_cumprod[t]
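
The a_t / a_bar_t values computed at the end of this hunk feed the standard DDPM ancestral-sampling update (Ho et al., 2020). A self-contained sketch of that update under the epsilon parameterization; this is the textbook rule, not necessarily this file's exact code:

import torch

def ddpm_reverse_step(x_t, eps_pred, a_t, a_bar_t, beta_t, t):
    # posterior mean of x_{t-1} given x_t and the predicted noise
    mean = (x_t - (1.0 - a_t) / torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_t)
    if t > 0:
        return mean + torch.sqrt(beta_t) * torch.randn_like(x_t)  # inject noise except at t == 0
    return mean
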
@@ -276,19 +308,21 @@ def main():
     x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
 
     if use_quantile:
-        x_cont = inverse_quantile_transform(x_cont, cont_cols, quantile_probs, quantile_values)
+        q_vals = {c: quantile_values[c] for c in model_cont_cols}
+        x_cont = inverse_quantile_transform(x_cont, model_cont_cols, quantile_probs, q_vals)
     else:
-        mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
-        std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
+        mean_vec = torch.tensor([mean[c] for c in model_cont_cols], dtype=x_cont.dtype)
+        std_vec = torch.tensor([std[c] for c in model_cont_cols], dtype=x_cont.dtype)
         x_cont = x_cont * std_vec + mean_vec
-        for i, c in enumerate(cont_cols):
+        for i, c in enumerate(model_cont_cols):
             if transforms.get(c) == "log1p":
                 x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
     if cont_post_calibrate and quantile_raw_values and quantile_probs:
-        x_cont = quantile_calibrate_to_real(x_cont, cont_cols, quantile_probs, quantile_raw_values)
+        q_raw = {c: quantile_raw_values[c] for c in model_cont_cols}
+        x_cont = quantile_calibrate_to_real(x_cont, model_cont_cols, quantile_probs, q_raw)
     # bound to observed min/max per feature
     if vmin and vmax:
-        for i, c in enumerate(cont_cols):
+        for i, c in enumerate(model_cont_cols):
             lo = vmin.get(c, None)
             hi = vmax.get(c, None)
             if lo is None or hi is None:
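
inverse_quantile_transform itself is not shown in this diff. A plausible single-feature version, assuming the model works in standard-normal scores and quantile_probs / quantile_values hold per-column empirical quantiles (an assumption, not the file's actual implementation):

import numpy as np
from scipy.stats import norm

def inverse_quantile_1d(z: np.ndarray, probs: np.ndarray, values: np.ndarray) -> np.ndarray:
    """Map normal scores back to the data scale for one feature."""
    u = norm.cdf(z)                     # normal scores -> uniform ranks in (0, 1)
    return np.interp(u, probs, values)  # ranks -> interpolated empirical quantiles

Filtering the quantile dicts down to model_cont_cols before the call keeps them aligned with the model's narrower channel order, which is the point of the q_vals / q_raw changes above.
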
@@ -310,7 +344,7 @@ def main():
 
     # optional post-scaling for problematic features
     if cont_post_scale:
-        for i, c in enumerate(cont_cols):
+        for i, c in enumerate(model_cont_cols):
             if c in cont_post_scale:
                 try:
                     scale = float(cont_post_scale[c])
@@ -318,6 +352,26 @@ def main():
                     scale = 1.0
                 x_cont[:, :, i] = x_cont[:, :, i] * scale
 
+    # assemble full continuous output
+    full_cont = torch.zeros(args.batch_size, args.seq_len, len(cont_cols), dtype=x_cont.dtype)
+    for i, c in enumerate(model_cont_cols):
+        full_idx = cont_cols.index(c)
+        full_cont[:, :, full_idx] = x_cont[:, :, i]
+    if cond_cont is not None and type1_cols:
+        mean_vec = torch.tensor([mean[c] for c in type1_cols], dtype=cond_cont.dtype)
+        std_vec = torch.tensor([std[c] for c in type1_cols], dtype=cond_cont.dtype)
+        cond_denorm = cond_cont.cpu() * std_vec + mean_vec
+        for i, c in enumerate(type1_cols):
+            full_idx = cont_cols.index(c)
+            full_cont[:, :, full_idx] = cond_denorm[:, :, i]
+    for c in type5_cols:
+        if c.endswith("Z"):
+            base = c[:-1]
+            if base in cont_cols:
+                bidx = cont_cols.index(base)
+                cidx = cont_cols.index(c)
+                full_cont[:, :, cidx] = full_cont[:, :, bidx]
+
     header = read_header(data_path)
     out_cols = [c for c in header if c != time_col or args.include_time]
     if args.include_condition and use_condition:
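
A toy run of the assembly logic above (shapes and column names invented):

import torch

cont_cols = ["cpu", "cpuZ", "mem"]
model_cont_cols = ["mem"]
x_cont = torch.randn(2, 4, len(model_cont_cols))   # B x T x |model channels|
full_cont = torch.zeros(2, 4, len(cont_cols))
for i, c in enumerate(model_cont_cols):
    full_cont[:, :, cont_cols.index(c)] = x_cont[:, :, i]
# type5 "...Z" columns mirror their base column, e.g. cpuZ := cpu
full_cont[:, :, cont_cols.index("cpuZ")] = full_cont[:, :, cont_cols.index("cpu")]
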
@@ -337,7 +391,7 @@ def main():
             if args.include_time and time_col in header:
                 row[time_col] = str(row_index)
             for i, c in enumerate(cont_cols):
-                val = float(x_cont[b, t, i])
+                val = float(full_cont[b, t, i])
                 if int_like.get(c, False):
                     row[c] = str(int(round(val)))
                 else: