update ks
This commit is contained in:
@@ -62,7 +62,7 @@ DEFAULTS = {
|
||||
"shuffle_buffer": 256,
|
||||
"cont_loss_weighting": "none", # none | inv_std
|
||||
"cont_loss_eps": 1e-6,
|
||||
"cont_target": "eps", # eps | x0
|
||||
"cont_target": "eps", # eps | x0 | v
|
||||
"cont_clamp_x0": 0.0,
|
||||
"quantile_loss_weight": 0.0,
|
||||
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
|
||||
@@ -259,6 +259,10 @@ def main():
|
||||
if float(config.get("cont_clamp_x0", 0.0)) > 0:
|
||||
x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"]))
|
||||
loss_base = (eps_pred - x0_target) ** 2
|
||||
elif cont_target == "v":
|
||||
a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
|
||||
v_target = torch.sqrt(a_bar_t) * noise - torch.sqrt(1.0 - a_bar_t) * x_cont
|
||||
loss_base = (eps_pred - v_target) ** 2
|
||||
else:
|
||||
loss_base = (eps_pred - noise) ** 2
|
||||
|
||||
|
||||
Reference in New Issue
Block a user