diff --git a/example/config.json b/example/config.json
index 36e6bd8..adda1ac 100644
--- a/example/config.json
+++ b/example/config.json
@@ -44,6 +44,8 @@
   "temporal_dropout": 0.0,
   "temporal_epochs": 2,
   "temporal_lr": 0.001,
+  "quantile_loss_weight": 0.1,
+  "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
   "sample_batch_size": 8,
   "sample_seq_len": 128
 }
diff --git a/example/train.py b/example/train.py
index 7281991..9e0fbad 100755
--- a/example/train.py
+++ b/example/train.py
@@ -71,6 +71,8 @@ DEFAULTS = {
     "temporal_dropout": 0.0,
     "temporal_epochs": 2,
     "temporal_lr": 1e-3,
+    "quantile_loss_weight": 0.0,
+    "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
 }
 
 
@@ -340,6 +342,27 @@ def main():
 
         lam = float(config["lambda"])
         loss = lam * loss_cont + (1 - lam) * loss_disc
+
+        # Optional quantile-matching regularizer: penalize the distance
+        # between the empirical quantiles of the reconstructed continuous
+        # values and those of the real residuals. Off by default (weight 0).
+        q_weight = float(config.get("quantile_loss_weight", 0.0))
+        if q_weight > 0:
+            q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
+            q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
+            a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
+            # Quantile loss on residual distribution
+            x_real = x_cont_resid
+            if cont_target == "x0":
+                x_gen = eps_pred
+            else:
+                x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
+            # reshape (not view): these tensors are not guaranteed contiguous
+            x_real = x_real.reshape(-1, x_real.size(-1))
+            x_gen = x_gen.reshape(-1, x_gen.size(-1))
+            q_real = torch.quantile(x_real, q_tensor, dim=0)
+            q_gen = torch.quantile(x_gen, q_tensor, dim=0)
+            quantile_loss = torch.mean(torch.abs(q_gen - q_real))
+            loss = loss + q_weight * quantile_loss
         opt.zero_grad()
         loss.backward()
         if float(config.get("grad_clip", 0.0)) > 0: