This commit is contained in:
2026-01-27 18:39:24 +08:00
parent c46c25d607
commit a24c60c506
22 changed files with 357 additions and 8 deletions

View File

@@ -138,6 +138,7 @@ def compute_cont_stats(
cont_cols: List[str],
max_rows: Optional[int] = None,
transforms: Optional[Dict[str, str]] = None,
quantile_bins: Optional[int] = None,
):
"""Compute stats on (optionally transformed) values. Returns raw + transformed stats."""
# First pass (raw) for metadata and raw mean/std
@@ -147,10 +148,11 @@ def compute_cont_stats(
if transforms is None:
transforms = {c: "none" for c in cont_cols}
# Second pass for transformed mean/std
# Second pass for transformed mean/std (and optional quantiles)
count = {c: 0 for c in cont_cols}
mean = {c: 0.0 for c in cont_cols}
m2 = {c: 0.0 for c in cont_cols}
quantile_values = {c: [] for c in cont_cols} if quantile_bins and quantile_bins > 1 else None
for i, row in enumerate(iter_rows(path)):
for c in cont_cols:
raw_val = row[c]
@@ -161,6 +163,8 @@ def compute_cont_stats(
if x < 0:
x = 0.0
x = math.log1p(x)
if quantile_values is not None:
quantile_values[c].append(x)
n = count[c] + 1
delta = x - mean[c]
mean[c] += delta / n
@@ -178,6 +182,25 @@ def compute_cont_stats(
var = 0.0
std[c] = var ** 0.5 if var > 0 else 1.0
quantile_probs = None
quantile_table = None
if quantile_values is not None:
quantile_probs = [i / (quantile_bins - 1) for i in range(quantile_bins)]
quantile_table = {}
for c in cont_cols:
vals = quantile_values[c]
if not vals:
quantile_table[c] = [0.0 for _ in quantile_probs]
continue
vals.sort()
n = len(vals)
qvals = []
for p in quantile_probs:
idx = int(round(p * (n - 1)))
idx = max(0, min(n - 1, idx))
qvals.append(float(vals[idx]))
quantile_table[c] = qvals
return {
"mean": mean,
"std": std,
@@ -191,6 +214,8 @@ def compute_cont_stats(
"skew": raw["skew"],
"all_pos": raw["all_pos"],
"max_rows": max_rows,
"quantile_probs": quantile_probs,
"quantile_values": quantile_table,
}
@@ -249,6 +274,9 @@ def normalize_cont(
mean: Dict[str, float],
std: Dict[str, float],
transforms: Optional[Dict[str, str]] = None,
quantile_probs: Optional[List[float]] = None,
quantile_values: Optional[Dict[str, List[float]]] = None,
use_quantile: bool = False,
):
import torch
@@ -256,11 +284,64 @@ def normalize_cont(
for i, c in enumerate(cont_cols):
if transforms.get(c) == "log1p":
x[:, :, i] = torch.log1p(torch.clamp(x[:, :, i], min=0))
if use_quantile:
if not quantile_probs or not quantile_values:
raise ValueError("use_quantile_transform enabled but quantile stats missing")
x = apply_quantile_transform(x, cont_cols, quantile_probs, quantile_values)
mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
return (x - mean_t) / std_t
def _normal_cdf(x):
import torch
return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def _normal_ppf(p):
import torch
eps = 1e-6
p = torch.clamp(p, eps, 1.0 - eps)
return math.sqrt(2.0) * torch.erfinv(2.0 * p - 1.0)
def apply_quantile_transform(x, cont_cols, quantile_probs, quantile_values):
    """Map each continuous column of ``x`` to an approximately N(0, 1) score.

    For every column, values are converted to empirical-CDF probabilities by
    linear interpolation between the precomputed quantile knots, then pushed
    through the normal PPF. ``x`` is modified in place and also returned.

    x: rank-3 float tensor indexed as ``x[:, :, col]`` — assumes
       (batch, window, n_cont) layout; TODO confirm against callers.
    quantile_probs: knot probabilities shared by all columns.
    quantile_values: per-column knot values aligned with ``quantile_probs``.
    """
    import torch
    probs = torch.tensor(quantile_probs, dtype=x.dtype, device=x.device)
    for col_idx, col in enumerate(cont_cols):
        knots = torch.tensor(quantile_values[col], dtype=x.dtype, device=x.device)
        v = x[:, :, col_idx]
        # Upper knot index per element; clamp so lo/hi always bracket a segment
        # (values outside the knot range are linearly extrapolated).
        hi = torch.clamp(torch.bucketize(v, knots), 1, knots.numel() - 1)
        lo = hi - 1
        x0, x1 = knots[lo], knots[hi]
        p0, p1 = probs[lo], probs[hi]
        # Guard flat segments (identical knot values) against division by zero.
        span = x1 - x0
        safe_span = torch.where(span == 0, torch.ones_like(span), span)
        p = p0 + (v - x0) * (p1 - p0) / safe_span
        x[:, :, col_idx] = _normal_ppf(p)
    return x
def inverse_quantile_transform(x, cont_cols, quantile_probs, quantile_values):
    """Approximate inverse of ``apply_quantile_transform``.

    Per column: normal scores in ``x`` are mapped back to probabilities via
    the normal CDF, then to data values by linear interpolation between the
    precomputed quantile knots. ``x`` is modified in place and also returned.

    x: rank-3 float tensor indexed as ``x[:, :, col]`` — assumes
       (batch, window, n_cont) layout; TODO confirm against callers.
    quantile_probs: knot probabilities shared by all columns.
    quantile_values: per-column knot values aligned with ``quantile_probs``.
    """
    import torch
    probs = torch.tensor(quantile_probs, dtype=x.dtype, device=x.device)
    for col_idx, col in enumerate(cont_cols):
        knots = torch.tensor(quantile_values[col], dtype=x.dtype, device=x.device)
        p = _normal_cdf(x[:, :, col_idx])
        # Upper knot index per element; clamp so lo/hi always bracket a segment.
        hi = torch.clamp(torch.bucketize(p, probs), 1, probs.numel() - 1)
        lo = hi - 1
        p0, p1 = probs[lo], probs[hi]
        x0, x1 = knots[lo], knots[hi]
        # Guard degenerate probability gaps against division by zero.
        gap = p1 - p0
        safe_gap = torch.where(gap == 0, torch.ones_like(gap), gap)
        x[:, :, col_idx] = x0 + (p - p0) * (x1 - x0) / safe_gap
    return x
def windowed_batches(
path: Union[str, List[str]],
cont_cols: List[str],
@@ -273,6 +354,9 @@ def windowed_batches(
max_batches: Optional[int] = None,
return_file_id: bool = False,
transforms: Optional[Dict[str, str]] = None,
quantile_probs: Optional[List[float]] = None,
quantile_values: Optional[Dict[str, List[float]]] = None,
use_quantile: bool = False,
shuffle_buffer: int = 0,
):
import torch
@@ -316,7 +400,16 @@ def windowed_batches(
if len(batch_cont) == batch_size:
x_cont = torch.tensor(batch_cont, dtype=torch.float32)
x_disc = torch.tensor(batch_disc, dtype=torch.long)
x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
x_cont = normalize_cont(
x_cont,
cont_cols,
mean,
std,
transforms=transforms,
quantile_probs=quantile_probs,
quantile_values=quantile_values,
use_quantile=use_quantile,
)
if return_file_id:
x_file = torch.tensor(batch_file, dtype=torch.long)
yield x_cont, x_disc, x_file
@@ -344,7 +437,16 @@ def windowed_batches(
import torch
x_cont = torch.tensor(batch_cont, dtype=torch.float32)
x_disc = torch.tensor(batch_disc, dtype=torch.long)
x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
x_cont = normalize_cont(
x_cont,
cont_cols,
mean,
std,
transforms=transforms,
quantile_probs=quantile_probs,
quantile_values=quantile_values,
use_quantile=use_quantile,
)
if return_file_id:
x_file = torch.tensor(batch_file, dtype=torch.long)
yield x_cont, x_disc, x_file