Fix quantile transform scaling and document

This commit is contained in:
2026-01-27 19:27:00 +08:00
parent 80e91443d2
commit 9e1e7338a2
4 changed files with 15 additions and 5 deletions

View File

@@ -271,11 +271,12 @@ def main():
if args.clip_k > 0:
x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
x_cont = x_cont * std_vec + mean_vec
if use_quantile:
x_cont = inverse_quantile_transform(x_cont, cont_cols, quantile_probs, quantile_values)
else:
mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
x_cont = x_cont * std_vec + mean_vec
for i, c in enumerate(cont_cols):
if transforms.get(c) == "log1p":
x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])