Optimize 6 classes; KS is now down to 0.28. Henceforth known as version 3.0.
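The KS value in the commit message most likely refers to a Kolmogorov–Smirnov statistic comparing real and generated feature distributions; the evaluation script is not part of this diff. A minimal sketch of how such a per-feature score could be computed, assuming real and synthetic samples are available as NumPy arrays with matching columns:

```python
# Hypothetical evaluation sketch; the actual metric code is not shown in this commit.
import numpy as np
from scipy.stats import ks_2samp


def mean_ks(real: np.ndarray, synth: np.ndarray) -> float:
    """Average two-sample KS statistic over feature columns (lower is better)."""
    assert real.shape[1] == synth.shape[1], "column counts must match"
    stats = [ks_2samp(real[:, j], synth[:, j]).statistic for j in range(real.shape[1])]
    return float(np.mean(stats))
```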
@@ -168,6 +168,14 @@ def main():
     cont_cols = [c for c in split["continuous"] if c != time_col]
     disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
 
+    type1_cols = config.get("type1_features", []) or []
+    type5_cols = config.get("type5_features", []) or []
+    type1_cols = [c for c in type1_cols if c in cont_cols]
+    type5_cols = [c for c in type5_cols if c in cont_cols]
+    model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
+    if not model_cont_cols:
+        raise SystemExit("model_cont_cols is empty; check type1/type5 config")
+
     stats = load_json(config["stats_path"])
     mean = stats["mean"]
     std = stats["std"]
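The hunk above splits the continuous columns three ways: type1 columns become conditioning inputs, type5 columns are dropped from the modeled set, and everything else stays in `model_cont_cols`. A small, hypothetical sanity check in the same spirit as the existing `SystemExit` guard could catch config typos early if run on the raw config lists before the filtering above (helper name and behavior are not part of the commit):

```python
# Hypothetical guard, not in the commit; expects the unfiltered type1/type5 lists.
def check_feature_split(cont_cols, type1_cols, type5_cols):
    unknown = [c for c in type1_cols + type5_cols if c not in cont_cols]
    overlap = sorted(set(type1_cols) & set(type5_cols))
    if unknown:
        raise SystemExit(f"unknown feature names in config: {unknown}")
    if overlap:
        raise SystemExit(f"columns listed as both type1 and type5: {overlap}")
```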
@@ -194,7 +202,7 @@ def main():
     device = resolve_device(str(config["device"]))
     print("device", device)
     model = HybridDiffusionModel(
-        cont_dim=len(cont_cols),
+        cont_dim=len(model_cont_cols),
         disc_vocab_sizes=vocab_sizes,
         time_dim=int(config.get("model_time_dim", 64)),
         hidden_dim=int(config.get("model_hidden_dim", 256)),
@@ -208,6 +216,7 @@ def main():
         transformer_nhead=int(config.get("transformer_nhead", 8)),
         transformer_ff_dim=int(config.get("transformer_ff_dim", 2048)),
         transformer_dropout=float(config.get("transformer_dropout", 0.1)),
+        cond_cont_dim=len(type1_cols),
         cond_vocab_size=cond_vocab_size,
         cond_dim=int(config.get("cond_dim", 32)),
         use_tanh_eps=bool(config.get("use_tanh_eps", False)),
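Together with the `cont_dim=len(model_cont_cols)` change above, the new `cond_cont_dim=len(type1_cols)` argument sizes the model from the reduced feature split. For orientation, a config fragment driving these keys might look as follows; the key names come from the diff, while the column names are placeholders only:

```python
# Illustrative fragment only; the real column names are project-specific.
config = {
    "type1_features": ["sensor_a", "sensor_b"],  # passed to the model as clean conditioning (cond_cont)
    "type5_features": ["sensor_c"],              # removed from the modeled continuous columns
    # ... other keys (stats_path, timesteps, model_* options, etc.) as before
}
```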
@@ -218,7 +227,7 @@ def main():
     opt_temporal = None
     if bool(config.get("use_temporal_stage1", False)):
         temporal_model = TemporalGRUGenerator(
-            input_dim=len(cont_cols),
+            input_dim=len(model_cont_cols),
             hidden_dim=int(config.get("temporal_hidden_dim", 256)),
             num_layers=int(config.get("temporal_num_layers", 1)),
             dropout=float(config.get("temporal_dropout", 0.0)),
@@ -264,8 +273,10 @@ def main():
            ):
                x_cont, _ = batch
                x_cont = x_cont.to(device)
-                trend, pred_next = temporal_model.forward_teacher(x_cont)
-                temporal_loss = F.mse_loss(pred_next, x_cont[:, 1:, :])
+                model_idx = [cont_cols.index(c) for c in model_cont_cols]
+                x_cont_model = x_cont[:, :, model_idx]
+                trend, pred_next = temporal_model.forward_teacher(x_cont_model)
+                temporal_loss = F.mse_loss(pred_next, x_cont_model[:, 1:, :])
                opt_temporal.zero_grad()
                temporal_loss.backward()
                if float(config.get("grad_clip", 0.0)) > 0:
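Both the stage-1 loop above and the main loop below rebuild `model_idx` with `list.index` on every batch. A possible follow-up, not part of this commit, is to precompute index tensors once before the loops and slice with them:

```python
# Possible refactor sketch, not in the commit; reuses cont_cols, model_cont_cols,
# type1_cols and device defined earlier in main().
import torch

model_idx_t = torch.tensor([cont_cols.index(c) for c in model_cont_cols], device=device)
cond_idx_t = (
    torch.tensor([cont_cols.index(c) for c in type1_cols], device=device)
    if type1_cols
    else None
)

# Inside the batch loops:
x_cont_model = x_cont[:, :, model_idx_t]
cond_cont = x_cont[:, :, cond_idx_t] if cond_idx_t is not None else None
```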
@@ -305,12 +316,17 @@ def main():
             x_cont = x_cont.to(device)
             x_disc = x_disc.to(device)
 
+            model_idx = [cont_cols.index(c) for c in model_cont_cols]
+            cond_idx = [cont_cols.index(c) for c in type1_cols] if type1_cols else []
+            x_cont_model = x_cont[:, :, model_idx]
+            cond_cont = x_cont[:, :, cond_idx] if cond_idx else None
+
             trend = None
             if temporal_model is not None:
                 temporal_model.eval()
                 with torch.no_grad():
-                    trend, _ = temporal_model.forward_teacher(x_cont)
-            x_cont_resid = x_cont if trend is None else x_cont - trend
+                    trend, _ = temporal_model.forward_teacher(x_cont_model)
+            x_cont_resid = x_cont_model if trend is None else x_cont_model - trend
 
             bsz = x_cont.size(0)
             t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
@@ -326,7 +342,7 @@ def main():
                 mask_scale=float(config.get("disc_mask_scale", 1.0)),
             )
 
-            eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)
+            eps_pred, logits = model(x_cont_t, x_disc_t, t, cond, cond_cont=cond_cont)
 
             cont_target = str(config.get("cont_target", "eps"))
             if cont_target == "x0":