#!/usr/bin/env python3 """Small utilities for HAI 21.03 data loading and feature encoding.""" import csv import gzip import json import math import random from typing import Dict, Iterable, List, Optional, Tuple, Union def load_split(path: str) -> Dict[str, List[str]]: with open(path, "r", encoding="utf-8") as f: return json.load(f) def iter_rows(path_or_paths: Union[str, List[str]]) -> Iterable[Dict[str, str]]: paths = [path_or_paths] if isinstance(path_or_paths, str) else list(path_or_paths) for path in paths: opener = gzip.open if str(path).endswith(".gz") else open with opener(path, "rt", newline="") as f: reader = csv.DictReader(f) for row in reader: yield row def _stream_basic_stats( path: Union[str, List[str]], cont_cols: List[str], max_rows: Optional[int] = None, ): """Streaming stats with mean/M2/M3 + min/max + int/precision metadata.""" count = {c: 0 for c in cont_cols} mean = {c: 0.0 for c in cont_cols} m2 = {c: 0.0 for c in cont_cols} m3 = {c: 0.0 for c in cont_cols} vmin = {c: float("inf") for c in cont_cols} vmax = {c: float("-inf") for c in cont_cols} int_like = {c: True for c in cont_cols} max_decimals = {c: 0 for c in cont_cols} all_pos = {c: True for c in cont_cols} for i, row in enumerate(iter_rows(path)): for c in cont_cols: raw = row[c] if raw is None or raw == "": continue x = float(raw) if x <= 0: all_pos[c] = False if x < vmin[c]: vmin[c] = x if x > vmax[c]: vmax[c] = x if int_like[c] and abs(x - round(x)) > 1e-9: int_like[c] = False if "e" not in raw and "E" not in raw and "." in raw: dec = raw.split(".", 1)[1].rstrip("0") if len(dec) > max_decimals[c]: max_decimals[c] = len(dec) n = count[c] + 1 delta = x - mean[c] delta_n = delta / n term1 = delta * delta_n * (n - 1) m3[c] += term1 * delta_n * (n - 2) - 3 * delta_n * m2[c] m2[c] += term1 mean[c] += delta_n count[c] = n if max_rows is not None and i + 1 >= max_rows: break # finalize std/skew std = {} skew = {} for c in cont_cols: n = count[c] if n > 1: var = m2[c] / (n - 1) else: var = 0.0 std[c] = var ** 0.5 if var > 0 else 1.0 if n > 2 and m2[c] > 0: skew[c] = (math.sqrt(n) * (m3[c] / n)) / (m2[c] ** 1.5) else: skew[c] = 0.0 for c in cont_cols: if vmin[c] == float("inf"): vmin[c] = 0.0 if vmax[c] == float("-inf"): vmax[c] = 0.0 return { "count": count, "mean": mean, "std": std, "m2": m2, "m3": m3, "min": vmin, "max": vmax, "int_like": int_like, "max_decimals": max_decimals, "skew": skew, "all_pos": all_pos, } def choose_cont_transforms( path: Union[str, List[str]], cont_cols: List[str], max_rows: Optional[int] = None, skew_threshold: float = 1.5, range_ratio_threshold: float = 1e3, ): """Pick per-column transform (currently log1p or none) based on skew/range.""" stats = _stream_basic_stats(path, cont_cols, max_rows=max_rows) transforms = {} for c in cont_cols: if not stats["all_pos"][c]: transforms[c] = "none" continue skew = abs(stats["skew"][c]) vmin = stats["min"][c] vmax = stats["max"][c] ratio = (vmax / vmin) if vmin > 0 else 0.0 if skew >= skew_threshold or ratio >= range_ratio_threshold: transforms[c] = "log1p" else: transforms[c] = "none" return transforms, stats def compute_cont_stats( path: Union[str, List[str]], cont_cols: List[str], max_rows: Optional[int] = None, transforms: Optional[Dict[str, str]] = None, quantile_bins: Optional[int] = None, ): """Compute stats on (optionally transformed) values. Returns raw + transformed stats.""" # First pass (raw) for metadata and raw mean/std raw = _stream_basic_stats(path, cont_cols, max_rows=max_rows) # Optional transform selection if transforms is None: transforms = {c: "none" for c in cont_cols} # Second pass for transformed mean/std (and optional quantiles) count = {c: 0 for c in cont_cols} mean = {c: 0.0 for c in cont_cols} m2 = {c: 0.0 for c in cont_cols} quantile_values = {c: [] for c in cont_cols} if quantile_bins and quantile_bins > 1 else None for i, row in enumerate(iter_rows(path)): for c in cont_cols: raw_val = row[c] if raw_val is None or raw_val == "": continue x = float(raw_val) if transforms.get(c) == "log1p": if x < 0: x = 0.0 x = math.log1p(x) if quantile_values is not None: quantile_values[c].append(x) n = count[c] + 1 delta = x - mean[c] mean[c] += delta / n delta2 = x - mean[c] m2[c] += delta * delta2 count[c] = n if max_rows is not None and i + 1 >= max_rows: break std = {} for c in cont_cols: if count[c] > 1: var = m2[c] / (count[c] - 1) else: var = 0.0 std[c] = var ** 0.5 if var > 0 else 1.0 quantile_probs = None quantile_table = None if quantile_values is not None: quantile_probs = [i / (quantile_bins - 1) for i in range(quantile_bins)] quantile_table = {} for c in cont_cols: vals = quantile_values[c] if not vals: quantile_table[c] = [0.0 for _ in quantile_probs] continue vals.sort() n = len(vals) qvals = [] for p in quantile_probs: idx = int(round(p * (n - 1))) idx = max(0, min(n - 1, idx)) qvals.append(float(vals[idx])) quantile_table[c] = qvals return { "mean": mean, "std": std, "raw_mean": raw["mean"], "raw_std": raw["std"], "min": raw["min"], "max": raw["max"], "int_like": raw["int_like"], "max_decimals": raw["max_decimals"], "transform": transforms, "skew": raw["skew"], "all_pos": raw["all_pos"], "max_rows": max_rows, "quantile_probs": quantile_probs, "quantile_values": quantile_table, } def build_vocab( path: Union[str, List[str]], disc_cols: List[str], max_rows: Optional[int] = None, ) -> Dict[str, Dict[str, int]]: values = {c: set() for c in disc_cols} for i, row in enumerate(iter_rows(path)): for c in disc_cols: values[c].add(row[c]) if max_rows is not None and i + 1 >= max_rows: break vocab = {} for c in disc_cols: tokens = sorted(values[c]) if "" not in tokens: tokens.append("") vocab[c] = {tok: idx for idx, tok in enumerate(tokens)} return vocab def build_disc_stats( path: Union[str, List[str]], disc_cols: List[str], max_rows: Optional[int] = None, ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, str]]: counts = {c: {} for c in disc_cols} for i, row in enumerate(iter_rows(path)): for c in disc_cols: val = row[c] counts[c][val] = counts[c].get(val, 0) + 1 if max_rows is not None and i + 1 >= max_rows: break vocab = {} top_token = {} for c in disc_cols: tokens = sorted(counts[c].keys()) if "" not in tokens: tokens.append("") vocab[c] = {tok: idx for idx, tok in enumerate(tokens)} # most frequent token if counts[c]: top_token[c] = max(counts[c].items(), key=lambda kv: kv[1])[0] else: top_token[c] = "" return vocab, top_token def normalize_cont( x, cont_cols: List[str], mean: Dict[str, float], std: Dict[str, float], transforms: Optional[Dict[str, str]] = None, quantile_probs: Optional[List[float]] = None, quantile_values: Optional[Dict[str, List[float]]] = None, use_quantile: bool = False, ): import torch if transforms: for i, c in enumerate(cont_cols): if transforms.get(c) == "log1p": x[:, :, i] = torch.log1p(torch.clamp(x[:, :, i], min=0)) if use_quantile: if not quantile_probs or not quantile_values: raise ValueError("use_quantile_transform enabled but quantile stats missing") x = apply_quantile_transform(x, cont_cols, quantile_probs, quantile_values) mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device) std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device) return (x - mean_t) / std_t def _normal_cdf(x): import torch return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) def _normal_ppf(p): import torch eps = 1e-6 p = torch.clamp(p, eps, 1.0 - eps) return math.sqrt(2.0) * torch.erfinv(2.0 * p - 1.0) def apply_quantile_transform(x, cont_cols, quantile_probs, quantile_values): import torch probs_t = torch.tensor(quantile_probs, dtype=x.dtype, device=x.device) for i, c in enumerate(cont_cols): q_vals = torch.tensor(quantile_values[c], dtype=x.dtype, device=x.device) v = x[:, :, i] idx = torch.bucketize(v, q_vals) idx = torch.clamp(idx, 1, q_vals.numel() - 1) x0 = q_vals[idx - 1] x1 = q_vals[idx] p0 = probs_t[idx - 1] p1 = probs_t[idx] denom = torch.where((x1 - x0) == 0, torch.ones_like(x1 - x0), (x1 - x0)) p = p0 + (v - x0) * (p1 - p0) / denom x[:, :, i] = _normal_ppf(p) return x def inverse_quantile_transform(x, cont_cols, quantile_probs, quantile_values): import torch probs_t = torch.tensor(quantile_probs, dtype=x.dtype, device=x.device) for i, c in enumerate(cont_cols): q_vals = torch.tensor(quantile_values[c], dtype=x.dtype, device=x.device) z = x[:, :, i] p = _normal_cdf(z) idx = torch.bucketize(p, probs_t) idx = torch.clamp(idx, 1, probs_t.numel() - 1) p0 = probs_t[idx - 1] p1 = probs_t[idx] x0 = q_vals[idx - 1] x1 = q_vals[idx] denom = torch.where((p1 - p0) == 0, torch.ones_like(p1 - p0), (p1 - p0)) v = x0 + (p - p0) * (x1 - x0) / denom x[:, :, i] = v return x def windowed_batches( path: Union[str, List[str]], cont_cols: List[str], disc_cols: List[str], vocab: Dict[str, Dict[str, int]], mean: Dict[str, float], std: Dict[str, float], batch_size: int, seq_len: int, max_batches: Optional[int] = None, return_file_id: bool = False, transforms: Optional[Dict[str, str]] = None, quantile_probs: Optional[List[float]] = None, quantile_values: Optional[Dict[str, List[float]]] = None, use_quantile: bool = False, shuffle_buffer: int = 0, ): import torch batch_cont = [] batch_disc = [] batch_file = [] buffer = [] seq_cont = [] seq_disc = [] def flush_seq(file_id: int): nonlocal seq_cont, seq_disc, batch_cont, batch_disc, batch_file if len(seq_cont) == seq_len: if shuffle_buffer and shuffle_buffer > 0: buffer.append((list(seq_cont), list(seq_disc), file_id)) if len(buffer) >= shuffle_buffer: idx = random.randrange(len(buffer)) seq_c, seq_d, seq_f = buffer.pop(idx) batch_cont.append(seq_c) batch_disc.append(seq_d) if return_file_id: batch_file.append(seq_f) else: batch_cont.append(seq_cont) batch_disc.append(seq_disc) if return_file_id: batch_file.append(file_id) seq_cont = [] seq_disc = [] batches_yielded = 0 paths = [path] if isinstance(path, str) else list(path) for file_id, p in enumerate(paths): for row in iter_rows(p): cont_row = [float(row[c]) for c in cont_cols] disc_row = [vocab[c].get(row[c], vocab[c][""]) for c in disc_cols] seq_cont.append(cont_row) seq_disc.append(disc_row) if len(seq_cont) == seq_len: flush_seq(file_id) if len(batch_cont) == batch_size: x_cont = torch.tensor(batch_cont, dtype=torch.float32) x_disc = torch.tensor(batch_disc, dtype=torch.long) x_cont = normalize_cont( x_cont, cont_cols, mean, std, transforms=transforms, quantile_probs=quantile_probs, quantile_values=quantile_values, use_quantile=use_quantile, ) if return_file_id: x_file = torch.tensor(batch_file, dtype=torch.long) yield x_cont, x_disc, x_file else: yield x_cont, x_disc batch_cont = [] batch_disc = [] batch_file = [] batches_yielded += 1 if max_batches is not None and batches_yielded >= max_batches: return # drop partial sequence at file boundary seq_cont = [] seq_disc = [] # Flush any remaining buffered sequences if shuffle_buffer and buffer: random.shuffle(buffer) for seq_c, seq_d, seq_f in buffer: batch_cont.append(seq_c) batch_disc.append(seq_d) if return_file_id: batch_file.append(seq_f) if len(batch_cont) == batch_size: import torch x_cont = torch.tensor(batch_cont, dtype=torch.float32) x_disc = torch.tensor(batch_disc, dtype=torch.long) x_cont = normalize_cont( x_cont, cont_cols, mean, std, transforms=transforms, quantile_probs=quantile_probs, quantile_values=quantile_values, use_quantile=use_quantile, ) if return_file_id: x_file = torch.tensor(batch_file, dtype=torch.long) yield x_cont, x_disc, x_file else: yield x_cont, x_disc batch_cont = [] batch_disc = [] batch_file = [] batches_yielded += 1 if max_batches is not None and batches_yielded >= max_batches: return # Drop last partial batch for simplicity