Fix quantile transform scaling and document the change

This commit is contained in:
2026-01-27 19:27:00 +08:00
parent 80e91443d2
commit 9e1e7338a2
4 changed files with 15 additions and 5 deletions

View File

@@ -288,6 +288,8 @@ def normalize_cont(
if not quantile_probs or not quantile_values:
raise ValueError("use_quantile_transform enabled but quantile stats missing")
x = apply_quantile_transform(x, cont_cols, quantile_probs, quantile_values)
# quantile transform already targets N(0,1); skip extra standardization
return x
mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
return (x - mean_t) / std_t

View File

@@ -271,11 +271,12 @@ def main():
if args.clip_k > 0:
x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
x_cont = x_cont * std_vec + mean_vec
if use_quantile:
x_cont = inverse_quantile_transform(x_cont, cont_cols, quantile_probs, quantile_values)
else:
mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
x_cont = x_cont * std_vec + mean_vec
for i, c in enumerate(cont_cols):
if transforms.get(c) == "log1p":
x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])