update
This commit is contained in:
@@ -44,6 +44,7 @@ def main():
|
||||
batch_size = int(cfg.get("sample_batch_size", cfg.get("batch_size", BATCH_SIZE)))
|
||||
clip_k = float(cfg.get("clip_k", CLIP_K))
|
||||
use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
|
||||
cont_pred = str(cfg.get("cont_pred", "eps")).lower()
|
||||
cond_dim = int(cfg.get("cond_dim", 32))
|
||||
use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
|
||||
eps_scale = float(cfg.get("eps_scale", 1.0))
|
||||
@@ -96,13 +97,17 @@ def main():
|
||||
|
||||
for t in reversed(range(timesteps)):
|
||||
t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
|
||||
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
|
||||
cont_pred_out, logits = model(x_cont, x_disc, t_batch, cond)
|
||||
|
||||
# Continuous reverse step (DDPM): x_{t-1} mean
|
||||
a_t = alphas[t]
|
||||
a_bar_t = alphas_cumprod[t]
|
||||
coef1 = 1.0 / torch.sqrt(a_t)
|
||||
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
|
||||
if cont_pred == "x0":
|
||||
eps_pred = (x_cont - torch.sqrt(a_bar_t) * cont_pred_out) / torch.sqrt(1 - a_bar_t + 1e-8)
|
||||
else:
|
||||
eps_pred = cont_pred_out
|
||||
mean = coef1 * (x_cont - coef2 * eps_pred)
|
||||
if t > 0:
|
||||
noise = torch.randn_like(x_cont)
|
||||
|
||||
Reference in New Issue
Block a user