update
This commit is contained in:
@@ -193,14 +193,14 @@ def main():
|
|||||||
t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
|
t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
|
||||||
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
|
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
|
||||||
|
|
||||||
|
a_t = alphas[t]
|
||||||
|
a_bar_t = alphas_cumprod[t]
|
||||||
|
|
||||||
if cont_target == "x0":
|
if cont_target == "x0":
|
||||||
x0_pred = eps_pred
|
x0_pred = eps_pred
|
||||||
if cont_clamp_x0 > 0:
|
if cont_clamp_x0 > 0:
|
||||||
x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
|
x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
|
||||||
eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
|
eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
|
||||||
|
|
||||||
a_t = alphas[t]
|
|
||||||
a_bar_t = alphas_cumprod[t]
|
|
||||||
coef1 = 1.0 / torch.sqrt(a_t)
|
coef1 = 1.0 / torch.sqrt(a_t)
|
||||||
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
|
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
|
||||||
mean_x = coef1 * (x_cont - coef2 * eps_pred)
|
mean_x = coef1 * (x_cont - coef2 * eps_pred)
|
||||||
|
|||||||
Reference in New Issue
Block a user