#!/usr/bin/env python3
"""Small utilities for HAI 21.03 data loading and feature encoding."""

import csv
import gzip
import json
from typing import Dict, Iterable, List, Optional, Tuple, Union


def load_split(path: str) -> Dict[str, List[str]]:
    """Load a JSON split file mapping split names to lists of file names."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def iter_rows(path_or_paths: Union[str, List[str]]) -> Iterable[Dict[str, str]]:
    """Yield CSV rows as string dicts from one path or a list of paths.

    Paths ending in ".gz" are transparently decompressed. Rows from multiple
    files are chained in the order the paths are given.
    """
    paths = [path_or_paths] if isinstance(path_or_paths, str) else list(path_or_paths)
    for path in paths:
        opener = gzip.open if str(path).endswith(".gz") else open
        with opener(path, "rt", newline="") as f:
            yield from csv.DictReader(f)


def compute_cont_stats(
    path: Union[str, List[str]],
    cont_cols: List[str],
    max_rows: Optional[int] = None,
) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, bool], Dict[str, int]]:
    """Streaming mean/std (Welford) + min/max + int/precision metadata.

    Missing values (None or "") are skipped. Each column keeps its OWN
    sample count so statistics stay correct for columns with gaps.
    (A previous version divided by the global row count, which biased the
    mean/std of any column containing missing values.)

    Returns:
        (mean, std, vmin, vmax, int_like, max_decimals), each a dict keyed
        by column name. ``std`` is 1.0 for constant/empty columns so it is
        always safe to divide by; ``vmin``/``vmax`` are 0.0 for columns
        with no valid values; ``max_decimals`` is the largest number of
        significant decimal digits seen in the raw strings.
    """
    n = {c: 0 for c in cont_cols}  # per-column count of valid samples
    mean = {c: 0.0 for c in cont_cols}
    m2 = {c: 0.0 for c in cont_cols}  # Welford sum of squared deviations
    vmin = {c: float("inf") for c in cont_cols}
    vmax = {c: float("-inf") for c in cont_cols}
    int_like = {c: True for c in cont_cols}
    max_decimals = {c: 0 for c in cont_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in cont_cols:
            raw = row[c]
            if raw is None or raw == "":
                continue  # missing value: does not advance n[c]
            x = float(raw)
            n[c] += 1
            delta = x - mean[c]
            mean[c] += delta / n[c]
            m2[c] += delta * (x - mean[c])
            if x < vmin[c]:
                vmin[c] = x
            if x > vmax[c]:
                vmax[c] = x
            if int_like[c] and abs(x - round(x)) > 1e-9:
                int_like[c] = False
            # Track decimal places from the raw string when possible.
            if "e" in raw or "E" in raw:
                continue  # scientific notation: skip precision inference
            if "." in raw:
                dec = raw.split(".", 1)[1].rstrip("0")
                if len(dec) > max_decimals[c]:
                    max_decimals[c] = len(dec)
        if max_rows is not None and i + 1 >= max_rows:
            break
    std = {}
    for c in cont_cols:
        # Sample variance over the column's own valid count.
        var = m2[c] / (n[c] - 1) if n[c] > 1 else 0.0
        std[c] = var ** 0.5 if var > 0 else 1.0
    # Replace infs for columns that had no valid values.
    for c in cont_cols:
        if vmin[c] == float("inf"):
            vmin[c] = 0.0
        if vmax[c] == float("-inf"):
            vmax[c] = 0.0
    return mean, std, vmin, vmax, int_like, max_decimals


def build_vocab(
    path: Union[str, List[str]],
    disc_cols: List[str],
    max_rows: Optional[int] = None,
) -> Dict[str, Dict[str, int]]:
    """Build a token->index vocab per discrete column.

    Tokens are sorted for determinism; "" is always present (used as the
    unknown/fallback token by ``windowed_batches``).
    """
    values = {c: set() for c in disc_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in disc_cols:
            values[c].add(row[c])
        if max_rows is not None and i + 1 >= max_rows:
            break
    vocab = {}
    for c in disc_cols:
        tokens = sorted(values[c])
        if "" not in tokens:
            tokens.append("")
        vocab[c] = {tok: idx for idx, tok in enumerate(tokens)}
    return vocab


def build_disc_stats(
    path: Union[str, List[str]],
    disc_cols: List[str],
    max_rows: Optional[int] = None,
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, str]]:
    """Like ``build_vocab`` but also returns the most frequent token per column.

    Returns:
        (vocab, top_token). ``vocab`` follows the same sorted+"" convention
        as ``build_vocab``; ``top_token[c]`` is the modal value of column c
        ("" if the column was never seen). Ties resolve to the first token
        encountered in the data.
    """
    counts = {c: {} for c in disc_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in disc_cols:
            val = row[c]
            counts[c][val] = counts[c].get(val, 0) + 1
        if max_rows is not None and i + 1 >= max_rows:
            break
    vocab = {}
    top_token = {}
    for c in disc_cols:
        tokens = sorted(counts[c].keys())
        if "" not in tokens:
            tokens.append("")
        vocab[c] = {tok: idx for idx, tok in enumerate(tokens)}
        # Most frequent token; dict iteration order makes ties deterministic.
        if counts[c]:
            top_token[c] = max(counts[c].items(), key=lambda kv: kv[1])[0]
        else:
            top_token[c] = ""
    return vocab, top_token


def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[str, float]):
    """Standardize a (..., len(cont_cols)) tensor with per-column mean/std."""
    import torch

    mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
    std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
    return (x - mean_t) / std_t


def windowed_batches(
    path: Union[str, List[str]],
    cont_cols: List[str],
    disc_cols: List[str],
    vocab: Dict[str, Dict[str, int]],
    mean: Dict[str, float],
    std: Dict[str, float],
    batch_size: int,
    seq_len: int,
    max_batches: Optional[int] = None,
    return_file_id: bool = False,
):
    """Yield batches of non-overlapping windows of length ``seq_len``.

    Each yielded batch is ``(x_cont, x_disc)`` — or ``(x_cont, x_disc,
    x_file)`` when ``return_file_id`` is True — where ``x_cont`` is a
    normalized float32 tensor of shape (batch, seq_len, n_cont), ``x_disc``
    a long tensor of vocab indices, and ``x_file`` the per-window source
    file index. Partial sequences at file boundaries and a trailing partial
    batch are dropped. Unknown discrete tokens map to the "" index.

    NOTE(review): continuous values are converted with ``float(row[c])``
    without the empty-string guard used in ``compute_cont_stats`` — an
    empty cell here raises ValueError. Presumably batching inputs are
    complete; confirm against the data.
    """
    import torch

    batch_cont = []
    batch_disc = []
    batch_file = []
    seq_cont = []
    seq_disc = []

    def flush_seq():
        # Move the completed window into the batch and start a new one.
        nonlocal seq_cont, seq_disc
        if len(seq_cont) == seq_len:
            batch_cont.append(seq_cont)
            batch_disc.append(seq_disc)
        seq_cont = []
        seq_disc = []

    batches_yielded = 0
    paths = [path] if isinstance(path, str) else list(path)
    for file_id, p in enumerate(paths):
        for row in iter_rows(p):
            cont_row = [float(row[c]) for c in cont_cols]
            disc_row = [vocab[c].get(row[c], vocab[c][""]) for c in disc_cols]
            seq_cont.append(cont_row)
            seq_disc.append(disc_row)
            if len(seq_cont) == seq_len:
                flush_seq()
                if return_file_id:
                    batch_file.append(file_id)
                if len(batch_cont) == batch_size:
                    x_cont = torch.tensor(batch_cont, dtype=torch.float32)
                    x_disc = torch.tensor(batch_disc, dtype=torch.long)
                    x_cont = normalize_cont(x_cont, cont_cols, mean, std)
                    if return_file_id:
                        x_file = torch.tensor(batch_file, dtype=torch.long)
                        yield x_cont, x_disc, x_file
                    else:
                        yield x_cont, x_disc
                    batch_cont = []
                    batch_disc = []
                    batch_file = []
                    batches_yielded += 1
                    if max_batches is not None and batches_yielded >= max_batches:
                        return
        # Drop partial sequence at file boundary.
        seq_cont = []
        seq_disc = []
    # Drop last partial batch for simplicity.