Optimize 6 classes; KS is now down to 0.28; henceforth known as version 3.0

2026-01-28 20:10:42 +08:00
parent 59697c0640
commit 39eede92f6
28 changed files with 3317 additions and 225 deletions


@@ -168,6 +168,14 @@ def main():
     cont_cols = [c for c in split["continuous"] if c != time_col]
     disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
+    type1_cols = config.get("type1_features", []) or []
+    type5_cols = config.get("type5_features", []) or []
+    type1_cols = [c for c in type1_cols if c in cont_cols]
+    type5_cols = [c for c in type5_cols if c in cont_cols]
+    model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
+    if not model_cont_cols:
+        raise SystemExit("model_cont_cols is empty; check type1/type5 config")
     stats = load_json(config["stats_path"])
     mean = stats["mean"]
     std = stats["std"]
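For reference, a minimal standalone sketch of the feature partition this hunk introduces, with hypothetical column names: type1 columns become conditioning inputs, type5 columns are dropped from the diffusion target, and everything else is modeled directly.

cont_cols = ["flow_rate", "pressure", "temp", "valve_pos"]  # hypothetical names
type1_cols = [c for c in ["pressure"] if c in cont_cols]    # conditioning features
type5_cols = [c for c in ["valve_pos"] if c in cont_cols]   # excluded features
model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
assert model_cont_cols, "model_cont_cols is empty; check type1/type5 config"
print(model_cont_cols)  # ['flow_rate', 'temp']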
@@ -194,7 +202,7 @@ def main():
     device = resolve_device(str(config["device"]))
     print("device", device)
     model = HybridDiffusionModel(
-        cont_dim=len(cont_cols),
+        cont_dim=len(model_cont_cols),
         disc_vocab_sizes=vocab_sizes,
         time_dim=int(config.get("model_time_dim", 64)),
         hidden_dim=int(config.get("model_hidden_dim", 256)),
@@ -208,6 +216,7 @@ def main():
         transformer_nhead=int(config.get("transformer_nhead", 8)),
         transformer_ff_dim=int(config.get("transformer_ff_dim", 2048)),
         transformer_dropout=float(config.get("transformer_dropout", 0.1)),
+        cond_cont_dim=len(type1_cols),
         cond_vocab_size=cond_vocab_size,
         cond_dim=int(config.get("cond_dim", 32)),
         use_tanh_eps=bool(config.get("use_tanh_eps", False)),
@@ -218,7 +227,7 @@ def main():
     opt_temporal = None
     if bool(config.get("use_temporal_stage1", False)):
         temporal_model = TemporalGRUGenerator(
-            input_dim=len(cont_cols),
+            input_dim=len(model_cont_cols),
             hidden_dim=int(config.get("temporal_hidden_dim", 256)),
             num_layers=int(config.get("temporal_num_layers", 1)),
             dropout=float(config.get("temporal_dropout", 0.0)),
@@ -264,8 +273,10 @@ def main():
            ):
                x_cont, _ = batch
                x_cont = x_cont.to(device)
-               trend, pred_next = temporal_model.forward_teacher(x_cont)
-               temporal_loss = F.mse_loss(pred_next, x_cont[:, 1:, :])
+               model_idx = [cont_cols.index(c) for c in model_cont_cols]
+               x_cont_model = x_cont[:, :, model_idx]
+               trend, pred_next = temporal_model.forward_teacher(x_cont_model)
+               temporal_loss = F.mse_loss(pred_next, x_cont_model[:, 1:, :])
                opt_temporal.zero_grad()
                temporal_loss.backward()
                if float(config.get("grad_clip", 0.0)) > 0:
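A self-contained sketch of the stage-1 objective in this hunk, assuming forward_teacher returns a per-step trend plus one-step-ahead predictions. The trend construction below is an assumption for illustration, not the repository's actual TemporalGRUGenerator.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyGRUGenerator(nn.Module):
    """Hypothetical stand-in for TemporalGRUGenerator's teacher-forced pass."""
    def __init__(self, input_dim, hidden_dim=256, num_layers=1, dropout=0.0):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, num_layers,
                          batch_first=True, dropout=dropout)
        self.head = nn.Linear(hidden_dim, input_dim)

    def forward_teacher(self, x):            # x: (batch, seq, input_dim)
        h, _ = self.gru(x)                   # hidden state at every step
        pred_next = self.head(h[:, :-1, :])  # predict step t+1 from steps <= t
        trend = torch.cat([x[:, :1, :], pred_next], dim=1)  # assumed alignment
        return trend, pred_next

x_cont_model = torch.randn(8, 50, 4)         # (batch, seq, n_model_cols)
gen = TinyGRUGenerator(input_dim=4)
trend, pred_next = gen.forward_teacher(x_cont_model)
temporal_loss = F.mse_loss(pred_next, x_cont_model[:, 1:, :])  # shifted-truth MSE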
@@ -305,12 +316,17 @@ def main():
             x_cont = x_cont.to(device)
             x_disc = x_disc.to(device)
+            model_idx = [cont_cols.index(c) for c in model_cont_cols]
+            cond_idx = [cont_cols.index(c) for c in type1_cols] if type1_cols else []
+            x_cont_model = x_cont[:, :, model_idx]
+            cond_cont = x_cont[:, :, cond_idx] if cond_idx else None
             trend = None
             if temporal_model is not None:
                 temporal_model.eval()
                 with torch.no_grad():
-                    trend, _ = temporal_model.forward_teacher(x_cont)
-            x_cont_resid = x_cont if trend is None else x_cont - trend
+                    trend, _ = temporal_model.forward_teacher(x_cont_model)
+            x_cont_resid = x_cont_model if trend is None else x_cont_model - trend
             bsz = x_cont.size(0)
             t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
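The column routing in this hunk, isolated as a runnable sketch with hypothetical shapes and names. Since model_idx and cond_idx depend only on the config, they could also be hoisted above the batch loop rather than recomputed per batch.

import torch

cont_cols = ["flow_rate", "pressure", "temp"]   # hypothetical
model_cont_cols = ["flow_rate", "temp"]
type1_cols = ["pressure"]
x_cont = torch.randn(8, 50, len(cont_cols))     # (batch, seq, n_cont)
model_idx = [cont_cols.index(c) for c in model_cont_cols]
cond_idx = [cont_cols.index(c) for c in type1_cols] if type1_cols else []
x_cont_model = x_cont[:, :, model_idx]                    # diffusion target columns
cond_cont = x_cont[:, :, cond_idx] if cond_idx else None  # type1 conditioning columns
trend = torch.zeros_like(x_cont_model)          # stand-in for the temporal trend
# Mirrors the source: diffuse the residual once a trend is available.
x_cont_resid = x_cont_model if trend is None else x_cont_model - trend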
@@ -326,7 +342,7 @@ def main():
                 mask_scale=float(config.get("disc_mask_scale", 1.0)),
             )
-            eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)
+            eps_pred, logits = model(x_cont_t, x_disc_t, t, cond, cond_cont=cond_cont)
             cont_target = str(config.get("cont_target", "eps"))
             if cont_target == "x0":
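A hypothetical stub showing just the new cond_cont pathway in the forward call above. HybridDiffusionModel's real implementation is not part of this diff, so everything here is an assumption: the type1 conditioning columns are embedded and concatenated onto the noised continuous input before predicting eps.

import torch
import torch.nn as nn

class CondContStub(nn.Module):
    """Hypothetical stand-in; the discrete branch and cond embedding are omitted."""
    def __init__(self, cont_dim, cond_cont_dim, cond_dim=32):
        super().__init__()
        self.cond_proj = nn.Linear(cond_cont_dim, cond_dim)
        self.backbone = nn.Linear(cont_dim + cond_dim, cont_dim)

    def forward(self, x_cont_t, t, cond_cont=None):   # t unused in this stub
        if cond_cont is None:                          # unconditional fallback
            c = x_cont_t.new_zeros(*x_cont_t.shape[:-1], self.cond_proj.out_features)
        else:
            c = self.cond_proj(cond_cont)              # (batch, seq, cond_dim)
        return self.backbone(torch.cat([x_cont_t, c], dim=-1))

stub = CondContStub(cont_dim=2, cond_cont_dim=1)
eps_pred = stub(torch.randn(8, 50, 2), torch.randint(0, 1000, (8,)),
                cond_cont=torch.randn(8, 50, 1))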