Fix cont loss weighting for filtered cont dims
This commit is contained in:
@@ -355,7 +355,7 @@ def main():
|
|||||||
|
|
||||||
if config.get("cont_loss_weighting") == "inv_std":
|
if config.get("cont_loss_weighting") == "inv_std":
|
||||||
weights = torch.tensor(
|
weights = torch.tensor(
|
||||||
[1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in cont_cols],
|
[1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in model_cont_cols],
|
||||||
device=device,
|
device=device,
|
||||||
dtype=eps_pred.dtype,
|
dtype=eps_pred.dtype,
|
||||||
).view(1, 1, -1)
|
).view(1, 1, -1)
|
||||||
|
|||||||
Reference in New Issue
Block a user