This commit is contained in:
2026-01-23 12:00:29 +08:00
parent 97e47be051
commit 0f74156460
5 changed files with 22 additions and 4 deletions

View File

@@ -179,14 +179,20 @@ def main():
cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
cond = cond_id
cont_pred = str(cfg.get("cont_pred", "eps")).lower()
for t in reversed(range(args.timesteps)):
t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
cont_pred_out, logits = model(x_cont, x_disc, t_batch, cond)
a_t = alphas[t]
a_bar_t = alphas_cumprod[t]
coef1 = 1.0 / torch.sqrt(a_t)
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
if cont_pred == "x0":
# eps = (x_t - sqrt(a_bar) * x0) / sqrt(1 - a_bar)
eps_pred = (x_cont - torch.sqrt(a_bar_t) * cont_pred_out) / torch.sqrt(1 - a_bar_t + 1e-8)
else:
eps_pred = cont_pred_out
mean_x = coef1 * (x_cont - coef2 * eps_pred)
if t > 0:
noise = torch.randn_like(x_cont)