This commit is contained in:
2026-01-24 00:34:28 +08:00
parent 743d6bb857
commit 444ecd856b
4 changed files with 31 additions and 2 deletions

View File

@@ -62,6 +62,8 @@ DEFAULTS = {
"shuffle_buffer": 256,
"cont_loss_weighting": "none", # none | inv_std
"cont_loss_eps": 1e-6,
"cont_target": "eps", # eps | x0
"cont_clamp_x0": 0.0,
}
@@ -249,15 +251,24 @@ def main():
eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)
cont_target = str(config.get("cont_target", "eps"))
if cont_target == "x0":
x0_target = x_cont
if float(config.get("cont_clamp_x0", 0.0)) > 0:
x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"]))
loss_base = (eps_pred - x0_target) ** 2
else:
loss_base = (eps_pred - noise) ** 2
if config.get("cont_loss_weighting") == "inv_std":
weights = torch.tensor(
[1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in cont_cols],
device=device,
dtype=eps_pred.dtype,
).view(1, 1, -1)
loss_cont = ((eps_pred - noise) ** 2 * weights).mean()
loss_cont = (loss_base * weights).mean()
else:
loss_cont = F.mse_loss(eps_pred, noise)
loss_cont = loss_base.mean()
loss_disc = 0.0
loss_disc_count = 0
for i, logit in enumerate(logits):