Fix cont loss weighting for filtered cont dims

This commit is contained in:
2026-01-28 22:27:23 +08:00
parent 8db286792e
commit 6fb53dd5c1

View File

@@ -355,7 +355,7 @@ def main():
if config.get("cont_loss_weighting") == "inv_std": if config.get("cont_loss_weighting") == "inv_std":
weights = torch.tensor( weights = torch.tensor(
[1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in cont_cols], [1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in model_cont_cols],
device=device, device=device,
dtype=eps_pred.dtype, dtype=eps_pred.dtype,
).view(1, 1, -1) ).view(1, 1, -1)