update
This commit is contained in:
@@ -47,6 +47,8 @@ def main():
|
||||
cond_dim = int(cfg.get("cond_dim", 32))
|
||||
use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
|
||||
eps_scale = float(cfg.get("eps_scale", 1.0))
|
||||
cont_target = str(cfg.get("cont_target", "eps"))
|
||||
cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
|
||||
model_time_dim = int(cfg.get("model_time_dim", 64))
|
||||
model_hidden_dim = int(cfg.get("model_hidden_dim", 256))
|
||||
model_num_layers = int(cfg.get("model_num_layers", 1))
|
||||
@@ -112,6 +114,12 @@ def main():
|
||||
t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
|
||||
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
|
||||
|
||||
if cont_target == "x0":
|
||||
x0_pred = eps_pred
|
||||
if cont_clamp_x0 > 0:
|
||||
x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
|
||||
eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
|
||||
|
||||
# Continuous reverse step (DDPM): x_{t-1} mean
|
||||
a_t = alphas[t]
|
||||
a_bar_t = alphas_cumprod[t]
|
||||
|
||||
Reference in New Issue
Block a user