update
@@ -11,7 +11,7 @@
   "seq_len": 128,
   "epochs": 10,
   "max_batches": 4000,
-  "lambda": 0.5,
+  "lambda": 0.7,
   "lr": 0.0005,
   "seed": 1337,
   "log_every": 10,
@@ -33,6 +33,8 @@
   "model_pos_dim": 64,
   "model_use_pos_embed": true,
   "disc_mask_scale": 0.9,
+  "cont_loss_weighting": "inv_std",
+  "cont_loss_eps": 1e-6,
   "shuffle_buffer": 256,
   "sample_batch_size": 8,
   "sample_seq_len": 128
@@ -60,6 +60,8 @@ DEFAULTS = {
     "model_use_pos_embed": True,
     "disc_mask_scale": 0.9,
     "shuffle_buffer": 256,
+    "cont_loss_weighting": "none",  # none | inv_std
+    "cont_loss_eps": 1e-6,
 }
 
 
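The two new defaults select how the continuous noise-prediction loss is weighted: "cont_loss_weighting" is either "none" (plain MSE) or "inv_std" (per-column inverse-variance weighting from the dataset's raw standard deviations), and "cont_loss_eps" guards the division for near-constant columns. A quick numeric check of the weight formula used further down in this commit; the column names and std values here are made up for illustration:

raw_std = {"price": 0.5, "volume": 2.0}  # hypothetical per-column raw standard deviations
cont_cols = ["price", "volume"]
eps = 1e-6
weights = [1.0 / (float(raw_std[c]) ** 2 + eps) for c in cont_cols]
print(weights)  # ~[4.0, 0.25]: low-variance columns are up-weighted, high-variance ones down-weighted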
@@ -131,6 +133,8 @@ class EMA:
 
 def main():
     args = parse_args()
+    if args.config:
+        print("using_config", str(Path(args.config).resolve()))
     config = dict(DEFAULTS)
     if args.config:
         cfg_path = Path(args.config).resolve()
@@ -154,6 +158,7 @@ def main():
     mean = stats["mean"]
     std = stats["std"]
     transforms = stats.get("transform", {})
+    raw_std = stats.get("raw_std", std)
 
     vocab = load_json(config["vocab_path"])["vocab"]
     vocab_sizes = [len(vocab[c]) for c in disc_cols]
@@ -244,7 +249,15 @@ def main():
 
         eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)
 
-        loss_cont = F.mse_loss(eps_pred, noise)
+        if config.get("cont_loss_weighting") == "inv_std":
+            weights = torch.tensor(
+                [1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in cont_cols],
+                device=device,
+                dtype=eps_pred.dtype,
+            ).view(1, 1, -1)
+            loss_cont = ((eps_pred - noise) ** 2 * weights).mean()
+        else:
+            loss_cont = F.mse_loss(eps_pred, noise)
         loss_disc = 0.0
         loss_disc_count = 0
         for i, logit in enumerate(logits):
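For reference, a self-contained sketch of the weighting branch added above, under the assumption (from context) that eps_pred and noise are [batch, seq_len, n_cont_cols] tensors and raw_std maps each continuous column name to its pre-normalization standard deviation:

import torch
import torch.nn.functional as F

def continuous_loss(eps_pred, noise, raw_std, cont_cols, weighting="inv_std", eps=1e-6):
    # Plain epsilon-prediction MSE when weighting is disabled.
    if weighting != "inv_std":
        return F.mse_loss(eps_pred, noise)
    # Per-column weights 1 / (raw_std^2 + eps), broadcast over batch and sequence dims.
    weights = torch.tensor(
        [1.0 / (float(raw_std[c]) ** 2 + eps) for c in cont_cols],
        device=eps_pred.device,
        dtype=eps_pred.dtype,
    ).view(1, 1, -1)
    return ((eps_pred - noise) ** 2 * weights).mean()

# Usage with made-up shapes: batch=2, seq_len=4, two continuous columns.
eps_pred = torch.randn(2, 4, 2)
noise = torch.randn(2, 4, 2)
print(continuous_loss(eps_pred, noise, {"price": 0.5, "volume": 2.0}, ["price", "volume"]).item())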