update
This commit is contained in:
@@ -51,6 +51,7 @@ DEFAULTS = {
|
||||
"cond_dim": 32,
|
||||
"use_tanh_eps": True,
|
||||
"eps_scale": 1.0,
|
||||
"cont_pred": "eps",
|
||||
}
|
||||
|
||||
|
||||
@@ -161,6 +162,7 @@ def main():
|
||||
|
||||
device = resolve_device(str(config["device"]))
|
||||
print("device", device)
|
||||
cont_pred = str(config.get("cont_pred", "eps")).lower()
|
||||
model = HybridDiffusionModel(
|
||||
cont_dim=len(cont_cols),
|
||||
disc_vocab_sizes=vocab_sizes,
|
||||
@@ -217,9 +219,12 @@ def main():
|
||||
mask_tokens = torch.tensor(vocab_sizes, device=device)
|
||||
x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"]))
|
||||
|
||||
eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)
|
||||
cont_pred_out, logits = model(x_cont_t, x_disc_t, t, cond)
|
||||
|
||||
loss_cont = F.mse_loss(eps_pred, noise)
|
||||
if cont_pred == "x0":
|
||||
loss_cont = F.mse_loss(cont_pred_out, x_cont)
|
||||
else:
|
||||
loss_cont = F.mse_loss(cont_pred_out, noise)
|
||||
loss_disc = 0.0
|
||||
loss_disc_count = 0
|
||||
for i, logit in enumerate(logits):
|
||||
|
||||
Reference in New Issue
Block a user