This commit is contained in:
2026-01-26 22:56:34 +08:00
parent 311164c22d
commit c731adeea6
2 changed files with 22 additions and 0 deletions

View File

@@ -71,6 +71,8 @@ DEFAULTS = {
"temporal_dropout": 0.0,
"temporal_epochs": 2,
"temporal_lr": 1e-3,
"quantile_loss_weight": 0.0,
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
}
@@ -340,6 +342,24 @@ def main():
lam = float(config["lambda"])
loss = lam * loss_cont + (1 - lam) * loss_disc
q_weight = float(config.get("quantile_loss_weight", 0.0))
if q_weight > 0:
q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
# Quantile loss on residual distribution
x_real = x_cont_resid
if cont_target == "x0":
x_gen = eps_pred
else:
x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
x_real = x_real.view(-1, x_real.size(-1))
x_gen = x_gen.view(-1, x_gen.size(-1))
q_real = torch.quantile(x_real, q_tensor, dim=0)
q_gen = torch.quantile(x_gen, q_tensor, dim=0)
quantile_loss = torch.mean(torch.abs(q_gen - q_real))
loss = loss + q_weight * quantile_loss
opt.zero_grad()
loss.backward()
if float(config.get("grad_clip", 0.0)) > 0: