diff --git a/example/config.json b/example/config.json
index 58baa92..fab6320 100644
--- a/example/config.json
+++ b/example/config.json
@@ -11,7 +11,7 @@
   "seq_len": 128,
   "epochs": 10,
   "max_batches": 4000,
-  "lambda": 0.5,
+  "lambda": 0.7,
   "lr": 0.0005,
   "seed": 1337,
   "log_every": 10,
@@ -33,6 +33,8 @@
   "model_pos_dim": 64,
   "model_use_pos_embed": true,
   "disc_mask_scale": 0.9,
+  "cont_loss_weighting": "inv_std",
+  "cont_loss_eps": 1e-6,
   "shuffle_buffer": 256,
   "sample_batch_size": 8,
   "sample_seq_len": 128
diff --git a/example/train.py b/example/train.py
index 467eb76..516f158 100755
--- a/example/train.py
+++ b/example/train.py
@@ -60,6 +60,8 @@ DEFAULTS = {
     "model_use_pos_embed": True,
     "disc_mask_scale": 0.9,
     "shuffle_buffer": 256,
+    "cont_loss_weighting": "none",  # none | inv_std
+    "cont_loss_eps": 1e-6,
 }
 
 
@@ -131,6 +133,8 @@ class EMA:
 
 def main():
     args = parse_args()
+    if args.config:
+        print("using_config", str(Path(args.config).resolve()))
     config = dict(DEFAULTS)
     if args.config:
         cfg_path = Path(args.config).resolve()
@@ -154,6 +158,7 @@ def main():
     mean = stats["mean"]
     std = stats["std"]
     transforms = stats.get("transform", {})
+    raw_std = stats.get("raw_std", std)
 
     vocab = load_json(config["vocab_path"])["vocab"]
     vocab_sizes = [len(vocab[c]) for c in disc_cols]
@@ -244,7 +249,15 @@
 
         eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)
 
-        loss_cont = F.mse_loss(eps_pred, noise)
+        if config.get("cont_loss_weighting") == "inv_std":
+            weights = torch.tensor(
+                [1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in cont_cols],
+                device=device,
+                dtype=eps_pred.dtype,
+            ).view(1, 1, -1)
+            loss_cont = ((eps_pred - noise) ** 2 * weights).mean()
+        else:
+            loss_cont = F.mse_loss(eps_pred, noise)
 
         loss_disc = 0.0
         loss_disc_count = 0
         for i, logit in enumerate(logits):
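
Note on the new "inv_std" option (illustrative, not part of the patch): a minimal standalone sketch of what the new loss branch computes, assuming raw_std is a mapping from continuous column name to its pre-normalization standard deviation and cont_cols lists those columns in channel order. weighted_cont_loss is a hypothetical helper written for this note, not a function in train.py.

    import torch

    def weighted_cont_loss(eps_pred, noise, raw_std, cont_cols, eps=1e-6):
        # One scalar weight per continuous channel: 1 / (raw_std^2 + eps).
        # eps guards against division by ~zero for near-constant columns.
        weights = torch.tensor(
            [1.0 / (float(raw_std[c]) ** 2 + eps) for c in cont_cols],
            dtype=eps_pred.dtype,
            device=eps_pred.device,
        ).view(1, 1, -1)  # broadcasts over (batch, seq_len, n_cont)
        # Per-element squared error scaled per channel, then averaged;
        # with all weights equal to 1 this reduces to plain MSE.
        return ((eps_pred - noise) ** 2 * weights).mean()

With "none" the loss stays F.mse_loss; with "inv_std" channels whose raw data had low variance contribute more per unit of squared error, so they are not drowned out by high-variance channels in the mean.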