This commit is contained in:
2026-01-22 21:17:11 +08:00
parent 5a109f91ac
commit 178fb7441c
4 changed files with 102 additions and 12 deletions

View File

@@ -27,20 +27,41 @@ def compute_cont_stats(
path: Union[str, List[str]],
cont_cols: List[str],
max_rows: Optional[int] = None,
) -> Tuple[Dict[str, float], Dict[str, float]]:
"""Streaming mean/std (Welford)."""
) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, bool], Dict[str, int]]:
"""Streaming mean/std (Welford) + min/max + int/precision metadata."""
count = 0
mean = {c: 0.0 for c in cont_cols}
m2 = {c: 0.0 for c in cont_cols}
vmin = {c: float("inf") for c in cont_cols}
vmax = {c: float("-inf") for c in cont_cols}
int_like = {c: True for c in cont_cols}
max_decimals = {c: 0 for c in cont_cols}
for i, row in enumerate(iter_rows(path)):
count += 1
for c in cont_cols:
x = float(row[c])
raw = row[c]
if raw is None or raw == "":
continue
x = float(raw)
delta = x - mean[c]
mean[c] += delta / count
delta2 = x - mean[c]
m2[c] += delta * delta2
if x < vmin[c]:
vmin[c] = x
if x > vmax[c]:
vmax[c] = x
if int_like[c] and abs(x - round(x)) > 1e-9:
int_like[c] = False
# track decimal places from raw string if possible
if "e" in raw or "E" in raw:
# scientific notation, skip precision inference
continue
if "." in raw:
dec = raw.split(".", 1)[1].rstrip("0")
if len(dec) > max_decimals[c]:
max_decimals[c] = len(dec)
if max_rows is not None and i + 1 >= max_rows:
break
@@ -51,7 +72,13 @@ def compute_cont_stats(
else:
var = 0.0
std[c] = var ** 0.5 if var > 0 else 1.0
return mean, std
# replace infs if column had no valid values
for c in cont_cols:
if vmin[c] == float("inf"):
vmin[c] = 0.0
if vmax[c] == float("-inf"):
vmax[c] = 0.0
return mean, std, vmin, vmax, int_like, max_decimals
def build_vocab(
@@ -75,6 +102,34 @@ def build_vocab(
return vocab
def build_disc_stats(
    path: Union[str, List[str]],
    disc_cols: List[str],
    max_rows: Optional[int] = None,
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, str]]:
    """Build per-column vocabularies and most-frequent tokens for discrete columns.

    Streams rows from *path* via ``iter_rows`` and tallies every value seen
    in each column of *disc_cols*.

    Args:
        path: A single path or list of paths, forwarded to ``iter_rows``.
        disc_cols: Names of the discrete (categorical) columns to scan.
        max_rows: If given, stop after this many rows (useful for sampling).

    Returns:
        A pair ``(vocab, top_token)`` where ``vocab[c]`` maps each observed
        token — plus a ``"<UNK>"`` fallback if not already present — to an
        integer id assigned in sorted-token order, and ``top_token[c]`` is
        the most frequent token of column ``c`` (``"<UNK>"`` when the column
        saw no rows at all).
    """
    # Local import keeps the function self-contained; Counter replaces the
    # hand-rolled dict.get() tallying loop.
    from collections import Counter

    counts: Dict[str, Counter] = {c: Counter() for c in disc_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in disc_cols:
            counts[c][row[c]] += 1
        if max_rows is not None and i + 1 >= max_rows:
            break

    vocab: Dict[str, Dict[str, int]] = {}
    top_token: Dict[str, str] = {}
    for c in disc_cols:
        tokens = sorted(counts[c])
        if "<UNK>" not in tokens:
            tokens.append("<UNK>")
        vocab[c] = {tok: idx for idx, tok in enumerate(tokens)}
        # Counter.most_common breaks ties by first-seen order, matching the
        # original max(items, key=count) over an insertion-ordered dict.
        top_token[c] = counts[c].most_common(1)[0][0] if counts[c] else "<UNK>"
    return vocab, top_token
def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[str, float]):
import torch
mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)