diff --git a/example/train.py b/example/train.py index 0d14179..0852fbe 100755 --- a/example/train.py +++ b/example/train.py @@ -355,7 +355,7 @@ def main(): if config.get("cont_loss_weighting") == "inv_std": weights = torch.tensor( - [1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in cont_cols], + [1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in model_cont_cols], device=device, dtype=eps_pred.dtype, ).view(1, 1, -1)