Fix cont loss weighting for filtered cont dims
This commit is contained in:
@@ -355,7 +355,7 @@ def main():
|
||||
|
||||
if config.get("cont_loss_weighting") == "inv_std":
|
||||
weights = torch.tensor(
|
||||
[1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in cont_cols],
|
||||
[1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in model_cont_cols],
|
||||
device=device,
|
||||
dtype=eps_pred.dtype,
|
||||
).view(1, 1, -1)
|
||||
|
||||
Reference in New Issue
Block a user