diff --git a/example/config.json b/example/config.json
index fab6320..eefcd91 100644
--- a/example/config.json
+++ b/example/config.json
@@ -35,6 +35,8 @@
   "disc_mask_scale": 0.9,
   "cont_loss_weighting": "inv_std",
   "cont_loss_eps": 1e-6,
+  "cont_target": "x0",
+  "cont_clamp_x0": 5.0,
   "shuffle_buffer": 256,
   "sample_batch_size": 8,
   "sample_seq_len": 128
diff --git a/example/export_samples.py b/example/export_samples.py
index 39da225..0b4d5df 100644
--- a/example/export_samples.py
+++ b/example/export_samples.py
@@ -112,6 +112,8 @@ def main():
     int_like = stats.get("int_like", {})
     max_decimals = stats.get("max_decimals", {})
     transforms = stats.get("transform", {})
+    cont_target = str(cfg.get("cont_target", "eps"))
+    cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
 
     vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
     vocab = vocab_json["vocab"]
@@ -191,6 +193,13 @@ def main():
             t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
             eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
 
             a_t = alphas[t]
             a_bar_t = alphas_cumprod[t]
+            # Model was trained to predict x0; convert to eps for the DDPM update.
+            # Note: this must run after a_bar_t is assigned above.
+            if cont_target == "x0":
+                x0_pred = eps_pred
+                if cont_clamp_x0 > 0:
+                    x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
+                eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
             coef1 = 1.0 / torch.sqrt(a_t)
diff --git a/example/sample.py b/example/sample.py
index 1b5f85d..9f4a6e1 100755
--- a/example/sample.py
+++ b/example/sample.py
@@ -47,6 +47,8 @@ def main():
    cond_dim = int(cfg.get("cond_dim", 32))
    use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
    eps_scale = float(cfg.get("eps_scale", 1.0))
+    cont_target = str(cfg.get("cont_target", "eps"))
+    cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
    model_time_dim = int(cfg.get("model_time_dim", 64))
    model_hidden_dim = int(cfg.get("model_hidden_dim", 256))
    model_num_layers = int(cfg.get("model_num_layers", 1))
@@ -112,6 +114,13 @@ def main():
        t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
        eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
 
        # Continuous reverse step (DDPM): x_{t-1} mean
        a_t = alphas[t]
        a_bar_t = alphas_cumprod[t]
+        # Model was trained to predict x0; convert to eps for the update below.
+        # Note: this must run after a_bar_t is assigned above.
+        if cont_target == "x0":
+            x0_pred = eps_pred
+            if cont_clamp_x0 > 0:
+                x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
+            eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
diff --git a/example/train.py b/example/train.py
index 516f158..a348f57 100755
--- a/example/train.py
+++ b/example/train.py
@@ -62,6 +62,8 @@
     "shuffle_buffer": 256,
     "cont_loss_weighting": "none",  # none | inv_std
     "cont_loss_eps": 1e-6,
+    "cont_target": "eps",  # eps | x0
+    "cont_clamp_x0": 0.0,
 }
 
 
@@ -249,15 +251,25 @@ def main():
        eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)
 
+        # Regression target: the noise (eps) or the clean values (x0, optionally clamped)
+        cont_target = str(config.get("cont_target", "eps"))
+        if cont_target == "x0":
+            x0_target = x_cont
+            if float(config.get("cont_clamp_x0", 0.0)) > 0:
+                x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"]))
+            loss_base = (eps_pred - x0_target) ** 2
+        else:
+            loss_base = (eps_pred - noise) ** 2
+
        if config.get("cont_loss_weighting") == "inv_std":
            weights = torch.tensor(
                [1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in cont_cols],
                device=device,
                dtype=eps_pred.dtype,
            ).view(1, 1, -1)
-            loss_cont = ((eps_pred - noise) ** 2 * weights).mean()
+            loss_cont = (loss_base * weights).mean()
        else:
-            loss_cont = F.mse_loss(eps_pred, noise)
+            loss_cont = loss_base.mean()
 
        loss_disc = 0.0
        loss_disc_count = 0
        for i, logit in enumerate(logits):
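
Note on the x0 path: the block added to sample.py and export_samples.py simply inverts the DDPM forward relation x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * eps to recover eps from the predicted x0. A minimal standalone sanity check of that identity, assuming only torch (the shapes and the a_bar_t value below are illustrative, not taken from the repo):

    import torch

    torch.manual_seed(0)
    a_bar_t = torch.tensor(0.37)   # illustrative cumulative alpha; any value in (0, 1)
    x0 = torch.randn(8, 16)        # stand-in for clean continuous values
    eps = torch.randn(8, 16)       # stand-in for the true noise
    x_t = torch.sqrt(a_bar_t) * x0 + torch.sqrt(1.0 - a_bar_t) * eps

    # Same formula the patch applies to x0_pred (exact when no clamping fires):
    eps_rec = (x_t - torch.sqrt(a_bar_t) * x0) / torch.sqrt(1.0 - a_bar_t)
    assert torch.allclose(eps_rec, eps, atol=1e-5)

When cont_clamp_x0 > 0 and the prediction actually gets clamped, the recovered eps is no longer the model's raw output; the clamp acts as the usual static-threshold guard against out-of-range x0 predictions.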