update
This commit is contained in:
@@ -27,20 +27,41 @@ def compute_cont_stats(
|
||||
path: Union[str, List[str]],
|
||||
cont_cols: List[str],
|
||||
max_rows: Optional[int] = None,
|
||||
) -> Tuple[Dict[str, float], Dict[str, float]]:
|
||||
"""Streaming mean/std (Welford)."""
|
||||
) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, bool], Dict[str, int]]:
|
||||
"""Streaming mean/std (Welford) + min/max + int/precision metadata."""
|
||||
count = 0
|
||||
mean = {c: 0.0 for c in cont_cols}
|
||||
m2 = {c: 0.0 for c in cont_cols}
|
||||
vmin = {c: float("inf") for c in cont_cols}
|
||||
vmax = {c: float("-inf") for c in cont_cols}
|
||||
int_like = {c: True for c in cont_cols}
|
||||
max_decimals = {c: 0 for c in cont_cols}
|
||||
|
||||
for i, row in enumerate(iter_rows(path)):
|
||||
count += 1
|
||||
for c in cont_cols:
|
||||
x = float(row[c])
|
||||
raw = row[c]
|
||||
if raw is None or raw == "":
|
||||
continue
|
||||
x = float(raw)
|
||||
delta = x - mean[c]
|
||||
mean[c] += delta / count
|
||||
delta2 = x - mean[c]
|
||||
m2[c] += delta * delta2
|
||||
if x < vmin[c]:
|
||||
vmin[c] = x
|
||||
if x > vmax[c]:
|
||||
vmax[c] = x
|
||||
if int_like[c] and abs(x - round(x)) > 1e-9:
|
||||
int_like[c] = False
|
||||
# track decimal places from raw string if possible
|
||||
if "e" in raw or "E" in raw:
|
||||
# scientific notation, skip precision inference
|
||||
continue
|
||||
if "." in raw:
|
||||
dec = raw.split(".", 1)[1].rstrip("0")
|
||||
if len(dec) > max_decimals[c]:
|
||||
max_decimals[c] = len(dec)
|
||||
if max_rows is not None and i + 1 >= max_rows:
|
||||
break
|
||||
|
||||
@@ -51,7 +72,13 @@ def compute_cont_stats(
|
||||
else:
|
||||
var = 0.0
|
||||
std[c] = var ** 0.5 if var > 0 else 1.0
|
||||
return mean, std
|
||||
# replace infs if column had no valid values
|
||||
for c in cont_cols:
|
||||
if vmin[c] == float("inf"):
|
||||
vmin[c] = 0.0
|
||||
if vmax[c] == float("-inf"):
|
||||
vmax[c] = 0.0
|
||||
return mean, std, vmin, vmax, int_like, max_decimals
|
||||
|
||||
|
||||
def build_vocab(
|
||||
@@ -75,6 +102,34 @@ def build_vocab(
|
||||
return vocab
|
||||
|
||||
|
||||
def build_disc_stats(
    path: Union[str, List[str]],
    disc_cols: List[str],
    max_rows: Optional[int] = None,
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, str]]:
    """Build a per-column vocabulary and most-frequent token for discrete columns.

    Rows are pulled from ``iter_rows(path)`` and each row is indexed by column
    name. For every column in ``disc_cols`` the occurrences of each value are
    tallied.

    Args:
        path: Passed through unchanged to ``iter_rows``.
        disc_cols: Names of the columns to tally.
        max_rows: If not ``None``, stop once this many rows have been consumed.
            The limit is checked *after* a row is tallied, so the first row is
            always counted when one exists.

    Returns:
        A ``(vocab, top_token)`` pair. ``vocab[col]`` maps each observed value
        (in sorted order, with ``"<UNK>"`` appended when it was not observed)
        to its index. ``top_token[col]`` is the value with the highest count
        (first one seen wins on ties), or ``"<UNK>"`` if nothing was tallied.
    """
    freq = {col: {} for col in disc_cols}

    rows_seen = 0
    for record in iter_rows(path):
        rows_seen += 1
        for col in disc_cols:
            tally = freq[col]
            token = record[col]
            tally[token] = tally.get(token, 0) + 1
        # Limit is evaluated after tallying, mirroring an ``i + 1 >= max_rows`` check.
        if max_rows is not None and rows_seen >= max_rows:
            break

    vocab = {}
    top_token = {}
    for col in disc_cols:
        tally = freq[col]

        ordered = sorted(tally)
        if "<UNK>" not in ordered:
            ordered.append("<UNK>")
        vocab[col] = dict(zip(ordered, range(len(ordered))))

        # ``max`` returns the first maximal key in insertion order, so ties
        # resolve to the value encountered earliest in the data.
        top_token[col] = max(tally, key=tally.get) if tally else "<UNK>"

    return vocab, top_token
|
||||
|
||||
|
||||
def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[str, float]):
|
||||
import torch
|
||||
mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
|
||||
|
||||
Reference in New Issue
Block a user