This commit is contained in:
2026-01-23 23:48:14 +08:00
parent ec39f7774b
commit 0caa80b6ef
2 changed files with 17 additions and 2 deletions

View File

@@ -60,6 +60,8 @@ DEFAULTS = {
"model_use_pos_embed": True,
"disc_mask_scale": 0.9,
"shuffle_buffer": 256,
"cont_loss_weighting": "none", # none | inv_std
"cont_loss_eps": 1e-6,
}
@@ -131,6 +133,8 @@ class EMA:
def main():
args = parse_args()
if args.config:
print("using_config", str(Path(args.config).resolve()))
config = dict(DEFAULTS)
if args.config:
cfg_path = Path(args.config).resolve()
@@ -154,6 +158,7 @@ def main():
mean = stats["mean"]
std = stats["std"]
transforms = stats.get("transform", {})
raw_std = stats.get("raw_std", std)
vocab = load_json(config["vocab_path"])["vocab"]
vocab_sizes = [len(vocab[c]) for c in disc_cols]
@@ -244,7 +249,15 @@ def main():
eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)
loss_cont = F.mse_loss(eps_pred, noise)
if config.get("cont_loss_weighting") == "inv_std":
weights = torch.tensor(
[1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in cont_cols],
device=device,
dtype=eps_pred.dtype,
).view(1, 1, -1)
loss_cont = ((eps_pred - noise) ** 2 * weights).mean()
else:
loss_cont = F.mse_loss(eps_pred, noise)
loss_disc = 0.0
loss_disc_count = 0
for i, logit in enumerate(logits):