update2
This commit is contained in:
@@ -138,6 +138,7 @@ def compute_cont_stats(
|
||||
cont_cols: List[str],
|
||||
max_rows: Optional[int] = None,
|
||||
transforms: Optional[Dict[str, str]] = None,
|
||||
quantile_bins: Optional[int] = None,
|
||||
):
|
||||
"""Compute stats on (optionally transformed) values. Returns raw + transformed stats."""
|
||||
# First pass (raw) for metadata and raw mean/std
|
||||
@@ -147,10 +148,11 @@ def compute_cont_stats(
|
||||
if transforms is None:
|
||||
transforms = {c: "none" for c in cont_cols}
|
||||
|
||||
# Second pass for transformed mean/std
|
||||
# Second pass for transformed mean/std (and optional quantiles)
|
||||
count = {c: 0 for c in cont_cols}
|
||||
mean = {c: 0.0 for c in cont_cols}
|
||||
m2 = {c: 0.0 for c in cont_cols}
|
||||
quantile_values = {c: [] for c in cont_cols} if quantile_bins and quantile_bins > 1 else None
|
||||
for i, row in enumerate(iter_rows(path)):
|
||||
for c in cont_cols:
|
||||
raw_val = row[c]
|
||||
@@ -161,6 +163,8 @@ def compute_cont_stats(
|
||||
if x < 0:
|
||||
x = 0.0
|
||||
x = math.log1p(x)
|
||||
if quantile_values is not None:
|
||||
quantile_values[c].append(x)
|
||||
n = count[c] + 1
|
||||
delta = x - mean[c]
|
||||
mean[c] += delta / n
|
||||
@@ -178,6 +182,25 @@ def compute_cont_stats(
|
||||
var = 0.0
|
||||
std[c] = var ** 0.5 if var > 0 else 1.0
|
||||
|
||||
quantile_probs = None
|
||||
quantile_table = None
|
||||
if quantile_values is not None:
|
||||
quantile_probs = [i / (quantile_bins - 1) for i in range(quantile_bins)]
|
||||
quantile_table = {}
|
||||
for c in cont_cols:
|
||||
vals = quantile_values[c]
|
||||
if not vals:
|
||||
quantile_table[c] = [0.0 for _ in quantile_probs]
|
||||
continue
|
||||
vals.sort()
|
||||
n = len(vals)
|
||||
qvals = []
|
||||
for p in quantile_probs:
|
||||
idx = int(round(p * (n - 1)))
|
||||
idx = max(0, min(n - 1, idx))
|
||||
qvals.append(float(vals[idx]))
|
||||
quantile_table[c] = qvals
|
||||
|
||||
return {
|
||||
"mean": mean,
|
||||
"std": std,
|
||||
@@ -191,6 +214,8 @@ def compute_cont_stats(
|
||||
"skew": raw["skew"],
|
||||
"all_pos": raw["all_pos"],
|
||||
"max_rows": max_rows,
|
||||
"quantile_probs": quantile_probs,
|
||||
"quantile_values": quantile_table,
|
||||
}
|
||||
|
||||
|
||||
@@ -249,6 +274,9 @@ def normalize_cont(
|
||||
mean: Dict[str, float],
|
||||
std: Dict[str, float],
|
||||
transforms: Optional[Dict[str, str]] = None,
|
||||
quantile_probs: Optional[List[float]] = None,
|
||||
quantile_values: Optional[Dict[str, List[float]]] = None,
|
||||
use_quantile: bool = False,
|
||||
):
|
||||
import torch
|
||||
|
||||
@@ -256,11 +284,64 @@ def normalize_cont(
|
||||
for i, c in enumerate(cont_cols):
|
||||
if transforms.get(c) == "log1p":
|
||||
x[:, :, i] = torch.log1p(torch.clamp(x[:, :, i], min=0))
|
||||
if use_quantile:
|
||||
if not quantile_probs or not quantile_values:
|
||||
raise ValueError("use_quantile_transform enabled but quantile stats missing")
|
||||
x = apply_quantile_transform(x, cont_cols, quantile_probs, quantile_values)
|
||||
mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
|
||||
std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
|
||||
return (x - mean_t) / std_t
|
||||
|
||||
|
||||
def _normal_cdf(x):
|
||||
import torch
|
||||
return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||
|
||||
|
||||
def _normal_ppf(p):
|
||||
import torch
|
||||
eps = 1e-6
|
||||
p = torch.clamp(p, eps, 1.0 - eps)
|
||||
return math.sqrt(2.0) * torch.erfinv(2.0 * p - 1.0)
|
||||
|
||||
|
||||
def apply_quantile_transform(x, cont_cols, quantile_probs, quantile_values):
    """Gaussianize each continuous channel of x via its empirical quantiles.

    For every column: map values through the piecewise-linear empirical CDF
    defined by (quantile_values[col], quantile_probs), then through the
    standard normal PPF, so the output is approximately N(0, 1). Mutates x
    in place and returns it.

    x is indexed as x[:, :, i] per continuous column — assumed
    (batch, window, n_cont); confirm against the caller.
    """
    import torch
    probs = torch.tensor(quantile_probs, dtype=x.dtype, device=x.device)
    for col_idx, col in enumerate(cont_cols):
        knots = torch.tensor(quantile_values[col], dtype=x.dtype, device=x.device)
        vals = x[:, :, col_idx]
        # Find, for each value, the pair of quantile knots that bracket it.
        pos = torch.bucketize(vals, knots)
        pos = torch.clamp(pos, 1, knots.numel() - 1)
        lo = knots[pos - 1]
        hi = knots[pos]
        p_lo = probs[pos - 1]
        p_hi = probs[pos]
        # Duplicate quantile values give zero-width segments; substitute a
        # width of 1 so the division is safe (numerator is 0 there anyway).
        width = hi - lo
        safe_width = torch.where(width == 0, torch.ones_like(width), width)
        cdf = p_lo + (vals - lo) * (p_hi - p_lo) / safe_width
        x[:, :, col_idx] = _normal_ppf(cdf)
    return x
|
||||
|
||||
|
||||
def inverse_quantile_transform(x, cont_cols, quantile_probs, quantile_values):
    """Undo apply_quantile_transform: map N(0,1)-ish values back to data space.

    For every column: convert z-scores to probabilities with the standard
    normal CDF, then invert the piecewise-linear empirical CDF defined by
    (quantile_probs, quantile_values[col]). Mutates x in place and returns it.

    x is indexed as x[:, :, i] per continuous column — assumed
    (batch, window, n_cont); confirm against the caller.
    """
    import torch
    probs = torch.tensor(quantile_probs, dtype=x.dtype, device=x.device)
    for col_idx, col in enumerate(cont_cols):
        knots = torch.tensor(quantile_values[col], dtype=x.dtype, device=x.device)
        z = x[:, :, col_idx]
        cdf = _normal_cdf(z)
        # Find the probability segment each CDF value falls into.
        pos = torch.bucketize(cdf, probs)
        pos = torch.clamp(pos, 1, probs.numel() - 1)
        p_lo = probs[pos - 1]
        p_hi = probs[pos]
        lo = knots[pos - 1]
        hi = knots[pos]
        # Equal adjacent probabilities give zero-width segments; substitute a
        # span of 1 so the division is safe (numerator is 0 there anyway).
        span = p_hi - p_lo
        safe_span = torch.where(span == 0, torch.ones_like(span), span)
        x[:, :, col_idx] = lo + (cdf - p_lo) * (hi - lo) / safe_span
    return x
|
||||
|
||||
|
||||
def windowed_batches(
|
||||
path: Union[str, List[str]],
|
||||
cont_cols: List[str],
|
||||
@@ -273,6 +354,9 @@ def windowed_batches(
|
||||
max_batches: Optional[int] = None,
|
||||
return_file_id: bool = False,
|
||||
transforms: Optional[Dict[str, str]] = None,
|
||||
quantile_probs: Optional[List[float]] = None,
|
||||
quantile_values: Optional[Dict[str, List[float]]] = None,
|
||||
use_quantile: bool = False,
|
||||
shuffle_buffer: int = 0,
|
||||
):
|
||||
import torch
|
||||
@@ -316,7 +400,16 @@ def windowed_batches(
|
||||
if len(batch_cont) == batch_size:
|
||||
x_cont = torch.tensor(batch_cont, dtype=torch.float32)
|
||||
x_disc = torch.tensor(batch_disc, dtype=torch.long)
|
||||
x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
|
||||
x_cont = normalize_cont(
|
||||
x_cont,
|
||||
cont_cols,
|
||||
mean,
|
||||
std,
|
||||
transforms=transforms,
|
||||
quantile_probs=quantile_probs,
|
||||
quantile_values=quantile_values,
|
||||
use_quantile=use_quantile,
|
||||
)
|
||||
if return_file_id:
|
||||
x_file = torch.tensor(batch_file, dtype=torch.long)
|
||||
yield x_cont, x_disc, x_file
|
||||
@@ -344,7 +437,16 @@ def windowed_batches(
|
||||
import torch
|
||||
x_cont = torch.tensor(batch_cont, dtype=torch.float32)
|
||||
x_disc = torch.tensor(batch_disc, dtype=torch.long)
|
||||
x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
|
||||
x_cont = normalize_cont(
|
||||
x_cont,
|
||||
cont_cols,
|
||||
mean,
|
||||
std,
|
||||
transforms=transforms,
|
||||
quantile_probs=quantile_probs,
|
||||
quantile_values=quantile_values,
|
||||
use_quantile=use_quantile,
|
||||
)
|
||||
if return_file_id:
|
||||
x_file = torch.tensor(batch_file, dtype=torch.long)
|
||||
yield x_cont, x_disc, x_file
|
||||
|
||||
Reference in New Issue
Block a user