This commit is contained in:
2026-01-23 12:00:29 +08:00
parent 97e47be051
commit 0f74156460
5 changed files with 22 additions and 4 deletions

View File

@@ -51,6 +51,7 @@ DEFAULTS = {
"cond_dim": 32,
"use_tanh_eps": True,
"eps_scale": 1.0,
"cont_pred": "eps",
}
@@ -161,6 +162,7 @@ def main():
device = resolve_device(str(config["device"]))
print("device", device)
cont_pred = str(config.get("cont_pred", "eps")).lower()
model = HybridDiffusionModel(
cont_dim=len(cont_cols),
disc_vocab_sizes=vocab_sizes,
@@ -217,9 +219,12 @@ def main():
mask_tokens = torch.tensor(vocab_sizes, device=device)
x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"]))
eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)
cont_pred_out, logits = model(x_cont_t, x_disc_t, t, cond)
loss_cont = F.mse_loss(eps_pred, noise)
if cont_pred == "x0":
loss_cont = F.mse_loss(cont_pred_out, x_cont)
else:
loss_cont = F.mse_loss(cont_pred_out, noise)
loss_disc = 0.0
loss_disc_count = 0
for i, logit in enumerate(logits):