#!/usr/bin/env python3
"""Small utilities for HAI 21.03 data loading and feature encoding."""

import csv
import gzip
import json
import math
import random
from typing import Dict, Iterable, List, Optional, Tuple, Union


def load_split(path: str) -> Dict[str, List[str]]:
    """Load a JSON split file mapping split names to lists of CSV paths."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def iter_rows(path_or_paths: Union[str, List[str]]) -> Iterable[Dict[str, str]]:
    """Yield CSV rows as dicts from one or more files (gzip detected by suffix)."""
    paths = [path_or_paths] if isinstance(path_or_paths, str) else list(path_or_paths)
    for path in paths:
        opener = gzip.open if str(path).endswith(".gz") else open
        with opener(path, "rt", newline="") as f:
            reader = csv.DictReader(f)
            for row in reader:
                yield row


def _stream_basic_stats(
    path: Union[str, List[str]],
    cont_cols: List[str],
    max_rows: Optional[int] = None,
):
    """Streaming stats with mean/M2/M3 + min/max + int/precision metadata."""
    count = {c: 0 for c in cont_cols}
    mean = {c: 0.0 for c in cont_cols}
    m2 = {c: 0.0 for c in cont_cols}
    m3 = {c: 0.0 for c in cont_cols}
    vmin = {c: float("inf") for c in cont_cols}
    vmax = {c: float("-inf") for c in cont_cols}
    int_like = {c: True for c in cont_cols}
    max_decimals = {c: 0 for c in cont_cols}
    all_pos = {c: True for c in cont_cols}

    for i, row in enumerate(iter_rows(path)):
        for c in cont_cols:
            raw = row[c]
            if raw is None or raw == "":
                continue
            x = float(raw)
            if x <= 0:
                all_pos[c] = False
            if x < vmin[c]:
                vmin[c] = x
            if x > vmax[c]:
                vmax[c] = x
            if int_like[c] and abs(x - round(x)) > 1e-9:
                int_like[c] = False
            if "e" not in raw and "E" not in raw and "." in raw:
                dec = raw.split(".", 1)[1].rstrip("0")
                if len(dec) > max_decimals[c]:
                    max_decimals[c] = len(dec)

            # One-pass update of mean/M2/M3 (Welford/Pebay recurrences).
            n = count[c] + 1
            delta = x - mean[c]
            delta_n = delta / n
            term1 = delta * delta_n * (n - 1)
            m3[c] += term1 * delta_n * (n - 2) - 3 * delta_n * m2[c]
            m2[c] += term1
            mean[c] += delta_n
            count[c] = n

        if max_rows is not None and i + 1 >= max_rows:
            break

    # Finalize std/skew from the accumulated moments.
    std = {}
    skew = {}
    for c in cont_cols:
        n = count[c]
        if n > 1:
            var = m2[c] / (n - 1)
        else:
            var = 0.0
        std[c] = var ** 0.5 if var > 0 else 1.0
        if n > 2 and m2[c] > 0:
            # Sample skewness g1 = sqrt(n) * M3 / M2**1.5.
            skew[c] = math.sqrt(n) * m3[c] / (m2[c] ** 1.5)
        else:
            skew[c] = 0.0

    # Columns with no observed values get neutral bounds.
    for c in cont_cols:
        if vmin[c] == float("inf"):
            vmin[c] = 0.0
        if vmax[c] == float("-inf"):
            vmax[c] = 0.0

    return {
        "count": count,
        "mean": mean,
        "std": std,
        "m2": m2,
        "m3": m3,
        "min": vmin,
        "max": vmax,
        "int_like": int_like,
        "max_decimals": max_decimals,
        "skew": skew,
        "all_pos": all_pos,
    }
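

# Hedged sanity check (column name and values are made up; not called by the
# pipeline): the streamed skewness above should match a direct two-pass
# computation of g1 = sqrt(n) * M3 / M2**1.5 on a throwaway CSV.
def _demo_skew_check():
    import os
    import tempfile

    vals = [0.0, 0.0, 1.0, 5.0, 100.0]
    with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as f:
        f.write("v\n" + "\n".join(str(v) for v in vals))
        tmp = f.name
    try:
        stats = _stream_basic_stats(tmp, ["v"])
    finally:
        os.remove(tmp)
    n = len(vals)
    mu = sum(vals) / n
    m2 = sum((v - mu) ** 2 for v in vals)
    m3 = sum((v - mu) ** 3 for v in vals)
    assert abs(stats["skew"]["v"] - math.sqrt(n) * m3 / m2 ** 1.5) < 1e-6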


def choose_cont_transforms(
    path: Union[str, List[str]],
    cont_cols: List[str],
    max_rows: Optional[int] = None,
    skew_threshold: float = 1.5,
    range_ratio_threshold: float = 1e3,
):
    """Pick a per-column transform (currently "log1p" or "none") based on skew/range."""
    stats = _stream_basic_stats(path, cont_cols, max_rows=max_rows)
    transforms = {}
    for c in cont_cols:
        # log1p is only applied to strictly positive columns.
        if not stats["all_pos"][c]:
            transforms[c] = "none"
            continue
        skew = abs(stats["skew"][c])
        vmin = stats["min"][c]
        vmax = stats["max"][c]
        ratio = (vmax / vmin) if vmin > 0 else 0.0
        if skew >= skew_threshold or ratio >= range_ratio_threshold:
            transforms[c] = "log1p"
        else:
            transforms[c] = "none"
    return transforms, stats
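

# Hedged usage sketch (throwaway CSV, made-up columns; not part of the
# pipeline): a strictly positive column whose max/min ratio exceeds
# range_ratio_threshold is routed to "log1p"; a narrow-range column stays "none".
def _demo_choose_transforms():
    import os
    import tempfile

    a_vals = [1, 1, 2, 1, 10000, 1, 1, 2]
    rows = ["a,b"] + ["%d,%d" % (a, i + 1) for i, a in enumerate(a_vals)]
    with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as f:
        f.write("\n".join(rows))
        tmp = f.name
    try:
        transforms, _stats = choose_cont_transforms(tmp, ["a", "b"])
    finally:
        os.remove(tmp)
    assert transforms == {"a": "log1p", "b": "none"}  # ratio 10000 >= 1e3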


def compute_cont_stats(
    path: Union[str, List[str]],
    cont_cols: List[str],
    max_rows: Optional[int] = None,
    transforms: Optional[Dict[str, str]] = None,
    quantile_bins: Optional[int] = None,
):
    """Compute stats on (optionally transformed) values. Returns raw + transformed stats."""
    # First pass (raw) for metadata and raw mean/std
    raw = _stream_basic_stats(path, cont_cols, max_rows=max_rows)

    # Optional transform selection
    if transforms is None:
        transforms = {c: "none" for c in cont_cols}

    # Second pass for transformed mean/std (and optional quantiles)
    count = {c: 0 for c in cont_cols}
    mean = {c: 0.0 for c in cont_cols}
    m2 = {c: 0.0 for c in cont_cols}
    quantile_values = {c: [] for c in cont_cols} if quantile_bins and quantile_bins > 1 else None
    for i, row in enumerate(iter_rows(path)):
        for c in cont_cols:
            raw_val = row[c]
            if raw_val is None or raw_val == "":
                continue
            x = float(raw_val)
            if transforms.get(c) == "log1p":
                if x < 0:
                    x = 0.0
                x = math.log1p(x)
            if quantile_values is not None:
                quantile_values[c].append(x)
            n = count[c] + 1
            delta = x - mean[c]
            mean[c] += delta / n
            delta2 = x - mean[c]
            m2[c] += delta * delta2
            count[c] = n
        if max_rows is not None and i + 1 >= max_rows:
            break

    std = {}
    for c in cont_cols:
        if count[c] > 1:
            var = m2[c] / (count[c] - 1)
        else:
            var = 0.0
        std[c] = var ** 0.5 if var > 0 else 1.0

    quantile_probs = None
    quantile_table = None
    if quantile_values is not None:
        quantile_probs = [i / (quantile_bins - 1) for i in range(quantile_bins)]
        quantile_table = {}
        for c in cont_cols:
            vals = quantile_values[c]
            if not vals:
                quantile_table[c] = [0.0 for _ in quantile_probs]
                continue
            vals.sort()
            n = len(vals)
            qvals = []
            for p in quantile_probs:
                idx = int(round(p * (n - 1)))
                idx = max(0, min(n - 1, idx))
                qvals.append(float(vals[idx]))
            quantile_table[c] = qvals

    return {
        "mean": mean,
        "std": std,
        "raw_mean": raw["mean"],
        "raw_std": raw["std"],
        "min": raw["min"],
        "max": raw["max"],
        "int_like": raw["int_like"],
        "max_decimals": raw["max_decimals"],
        "transform": transforms,
        "skew": raw["skew"],
        "all_pos": raw["all_pos"],
        "max_rows": max_rows,
        "quantile_probs": quantile_probs,
        "quantile_values": quantile_table,
    }
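

# Hedged sketch of the intended two-step flow (paths and column lists are the
# caller's; nothing here is wired to real data): choose transforms on the raw
# training files first, then compute normalization stats under those transforms.
def _demo_cont_stats_flow(train_paths: List[str], cont_cols: List[str]) -> dict:
    transforms, _raw = choose_cont_transforms(train_paths, cont_cols, max_rows=200_000)
    return compute_cont_stats(
        train_paths,
        cont_cols,
        transforms=transforms,
        quantile_bins=1000,  # also fills "quantile_probs"/"quantile_values"
    )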


def build_vocab(
    path: Union[str, List[str]],
    disc_cols: List[str],
    max_rows: Optional[int] = None,
) -> Dict[str, Dict[str, int]]:
    """Build a token -> index vocab per discrete column, with an "<UNK>" fallback."""
    values = {c: set() for c in disc_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in disc_cols:
            values[c].add(row[c])
        if max_rows is not None and i + 1 >= max_rows:
            break

    vocab = {}
    for c in disc_cols:
        tokens = sorted(values[c])
        if "<UNK>" not in tokens:
            tokens.append("<UNK>")
        vocab[c] = {tok: idx for idx, tok in enumerate(tokens)}
    return vocab


def build_disc_stats(
    path: Union[str, List[str]],
    disc_cols: List[str],
    max_rows: Optional[int] = None,
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, str]]:
    """Like build_vocab, but also return the most frequent token per column."""
    counts = {c: {} for c in disc_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in disc_cols:
            val = row[c]
            counts[c][val] = counts[c].get(val, 0) + 1
        if max_rows is not None and i + 1 >= max_rows:
            break

    vocab = {}
    top_token = {}
    for c in disc_cols:
        tokens = sorted(counts[c].keys())
        if "<UNK>" not in tokens:
            tokens.append("<UNK>")
        vocab[c] = {tok: idx for idx, tok in enumerate(tokens)}
        # most frequent token
        if counts[c]:
            top_token[c] = max(counts[c].items(), key=lambda kv: kv[1])[0]
        else:
            top_token[c] = "<UNK>"
    return vocab, top_token
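

# Hedged sketch (hypothetical helper, not used by the pipeline): encoding one
# row with the vocab. Unseen tokens fall back to "<UNK>" exactly as
# windowed_batches does below; top_token is available for imputation if wanted.
def _demo_encode_disc(
    row: Dict[str, str],
    disc_cols: List[str],
    vocab: Dict[str, Dict[str, int]],
) -> List[int]:
    return [vocab[c].get(row[c], vocab[c]["<UNK>"]) for c in disc_cols]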


def normalize_cont(
    x,
    cont_cols: List[str],
    mean: Dict[str, float],
    std: Dict[str, float],
    transforms: Optional[Dict[str, str]] = None,
    quantile_probs: Optional[List[float]] = None,
    quantile_values: Optional[Dict[str, List[float]]] = None,
    use_quantile: bool = False,
):
    """Apply per-column transforms, optional quantile mapping, then standardize."""
    import torch

    if transforms:
        for i, c in enumerate(cont_cols):
            if transforms.get(c) == "log1p":
                x[:, :, i] = torch.log1p(torch.clamp(x[:, :, i], min=0))
    if use_quantile:
        if not quantile_probs or not quantile_values:
            raise ValueError("use_quantile enabled but quantile stats missing")
        x = apply_quantile_transform(x, cont_cols, quantile_probs, quantile_values)
    mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
    std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
    return (x - mean_t) / std_t


def _normal_cdf(x):
    import torch
    return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def _normal_ppf(p):
    import torch
    eps = 1e-6
    p = torch.clamp(p, eps, 1.0 - eps)
    return math.sqrt(2.0) * torch.erfinv(2.0 * p - 1.0)


def apply_quantile_transform(x, cont_cols, quantile_probs, quantile_values):
    """Map values through the empirical CDF, then the normal PPF (in place)."""
    import torch
    probs_t = torch.tensor(quantile_probs, dtype=x.dtype, device=x.device)
    for i, c in enumerate(cont_cols):
        q_vals = torch.tensor(quantile_values[c], dtype=x.dtype, device=x.device)
        v = x[:, :, i]
        idx = torch.bucketize(v, q_vals)
        idx = torch.clamp(idx, 1, q_vals.numel() - 1)
        x0 = q_vals[idx - 1]
        x1 = q_vals[idx]
        p0 = probs_t[idx - 1]
        p1 = probs_t[idx]
        # Linear interpolation of the empirical CDF; guard zero-width bins.
        denom = torch.where((x1 - x0) == 0, torch.ones_like(x1 - x0), (x1 - x0))
        p = p0 + (v - x0) * (p1 - p0) / denom
        x[:, :, i] = _normal_ppf(p)
    return x


def inverse_quantile_transform(x, cont_cols, quantile_probs, quantile_values):
    """Invert apply_quantile_transform: normal CDF, then empirical quantiles (in place)."""
    import torch
    probs_t = torch.tensor(quantile_probs, dtype=x.dtype, device=x.device)
    for i, c in enumerate(cont_cols):
        q_vals = torch.tensor(quantile_values[c], dtype=x.dtype, device=x.device)
        z = x[:, :, i]
        p = _normal_cdf(z)
        idx = torch.bucketize(p, probs_t)
        idx = torch.clamp(idx, 1, probs_t.numel() - 1)
        p0 = probs_t[idx - 1]
        p1 = probs_t[idx]
        x0 = q_vals[idx - 1]
        x1 = q_vals[idx]
        denom = torch.where((p1 - p0) == 0, torch.ones_like(p1 - p0), (p1 - p0))
        v = x0 + (p - p0) * (x1 - x0) / denom
        x[:, :, i] = v
    return x
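

# Hedged round-trip check (requires torch; names/values are made up): interior
# points should survive apply_quantile_transform -> inverse_quantile_transform
# up to interpolation and erf/erfinv precision on the quantile grid.
def _demo_quantile_roundtrip():
    import torch

    probs = [i / 99 for i in range(100)]
    qvals = {"v": [float(i) for i in range(100)]}  # identity-like quantile table
    x = torch.tensor([[[5.0], [42.0], [80.0]]])  # (batch=1, seq=3, feat=1)
    z = apply_quantile_transform(x.clone(), ["v"], probs, qvals)
    back = inverse_quantile_transform(z, ["v"], probs, qvals)
    assert torch.allclose(back, x, atol=1e-3)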


def windowed_batches(
    path: Union[str, List[str]],
    cont_cols: List[str],
    disc_cols: List[str],
    vocab: Dict[str, Dict[str, int]],
    mean: Dict[str, float],
    std: Dict[str, float],
    batch_size: int,
    seq_len: int,
    max_batches: Optional[int] = None,
    return_file_id: bool = False,
    transforms: Optional[Dict[str, str]] = None,
    quantile_probs: Optional[List[float]] = None,
    quantile_values: Optional[Dict[str, List[float]]] = None,
    use_quantile: bool = False,
    shuffle_buffer: int = 0,
):
    """Yield (x_cont, x_disc[, x_file]) batches of fixed-length windows."""
    import torch

    batch_cont = []
    batch_disc = []
    batch_file = []
    buffer = []
    seq_cont = []
    seq_disc = []

    def flush_seq(file_id: int):
        nonlocal seq_cont, seq_disc, batch_cont, batch_disc, batch_file
        if len(seq_cont) == seq_len:
            if shuffle_buffer and shuffle_buffer > 0:
                # Buffer sequences and emit a random one once the buffer is
                # full, for approximate shuffling without loading everything.
                buffer.append((list(seq_cont), list(seq_disc), file_id))
                if len(buffer) >= shuffle_buffer:
                    idx = random.randrange(len(buffer))
                    seq_c, seq_d, seq_f = buffer.pop(idx)
                    batch_cont.append(seq_c)
                    batch_disc.append(seq_d)
                    if return_file_id:
                        batch_file.append(seq_f)
            else:
                batch_cont.append(seq_cont)
                batch_disc.append(seq_disc)
                if return_file_id:
                    batch_file.append(file_id)
        seq_cont = []
        seq_disc = []

    batches_yielded = 0
    paths = [path] if isinstance(path, str) else list(path)
    for file_id, p in enumerate(paths):
        for row in iter_rows(p):
            cont_row = [float(row[c]) for c in cont_cols]
            disc_row = [vocab[c].get(row[c], vocab[c]["<UNK>"]) for c in disc_cols]
            seq_cont.append(cont_row)
            seq_disc.append(disc_row)
            if len(seq_cont) == seq_len:
                flush_seq(file_id)
                if len(batch_cont) == batch_size:
                    x_cont = torch.tensor(batch_cont, dtype=torch.float32)
                    x_disc = torch.tensor(batch_disc, dtype=torch.long)
                    x_cont = normalize_cont(
                        x_cont,
                        cont_cols,
                        mean,
                        std,
                        transforms=transforms,
                        quantile_probs=quantile_probs,
                        quantile_values=quantile_values,
                        use_quantile=use_quantile,
                    )
                    if return_file_id:
                        x_file = torch.tensor(batch_file, dtype=torch.long)
                        yield x_cont, x_disc, x_file
                    else:
                        yield x_cont, x_disc
                    batch_cont = []
                    batch_disc = []
                    batch_file = []
                    batches_yielded += 1
                    if max_batches is not None and batches_yielded >= max_batches:
                        return
        # drop partial sequence at file boundary
        seq_cont = []
        seq_disc = []

    # Flush any remaining buffered sequences
    if shuffle_buffer and buffer:
        random.shuffle(buffer)
        for seq_c, seq_d, seq_f in buffer:
            batch_cont.append(seq_c)
            batch_disc.append(seq_d)
            if return_file_id:
                batch_file.append(seq_f)
            if len(batch_cont) == batch_size:
                x_cont = torch.tensor(batch_cont, dtype=torch.float32)
                x_disc = torch.tensor(batch_disc, dtype=torch.long)
                x_cont = normalize_cont(
                    x_cont,
                    cont_cols,
                    mean,
                    std,
                    transforms=transforms,
                    quantile_probs=quantile_probs,
                    quantile_values=quantile_values,
                    use_quantile=use_quantile,
                )
                if return_file_id:
                    x_file = torch.tensor(batch_file, dtype=torch.long)
                    yield x_cont, x_disc, x_file
                else:
                    yield x_cont, x_disc
                batch_cont = []
                batch_disc = []
                batch_file = []
                batches_yielded += 1
                if max_batches is not None and batches_yielded >= max_batches:
                    return

    # Drop last partial batch for simplicity