update ks

This commit is contained in:
2026-01-25 17:55:28 +08:00
parent 5ef1e465f9
commit cc10125fbf
4 changed files with 18 additions and 7 deletions

View File

@@ -201,6 +201,10 @@ def main():
if cont_clamp_x0 > 0:
x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
elif cont_target == "v":
v_pred = eps_pred
x0_pred = torch.sqrt(a_bar_t) * x_cont - torch.sqrt(1.0 - a_bar_t) * v_pred
eps_pred = torch.sqrt(1.0 - a_bar_t) * x_cont + torch.sqrt(a_bar_t) * v_pred
coef1 = 1.0 / torch.sqrt(a_t)
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
mean_x = coef1 * (x_cont - coef2 * eps_pred)