#!/usr/bin/env python3
"""Small utilities for HAI 21.03 data loading and feature encoding."""

import csv
import gzip
import json
from typing import Dict, Iterable, List, Optional, Tuple


def load_split(path: str) -> Dict[str, List[str]]:
    """Load a split definition (split name -> list of entries) from JSON."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def iter_rows(path: str) -> Iterable[Dict[str, str]]:
    """Stream rows from a gzipped CSV file as dicts, one at a time."""
    with gzip.open(path, "rt", newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            yield row


def compute_cont_stats(
    path: str,
    cont_cols: List[str],
    max_rows: Optional[int] = None,
) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Streaming per-column mean/std via Welford's online algorithm."""
    count = 0
    mean = {c: 0.0 for c in cont_cols}
    m2 = {c: 0.0 for c in cont_cols}

    for i, row in enumerate(iter_rows(path)):
        count += 1
        for c in cont_cols:
            x = float(row[c])
            # Welford update: shift the running mean toward x, then
            # accumulate squared deviations (m2) against the new mean.
            delta = x - mean[c]
            mean[c] += delta / count
            delta2 = x - mean[c]
            m2[c] += delta * delta2
        if max_rows is not None and i + 1 >= max_rows:
            break

    std = {}
    for c in cont_cols:
        if count > 1:
            var = m2[c] / (count - 1)
        else:
            var = 0.0
        # Constant (or empty) columns get std 1.0 so normalization is a no-op.
        std[c] = var ** 0.5 if var > 0 else 1.0
    return mean, std
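
# Sanity check (illustrative, not part of the original module): the Welford
# recurrences above,
#   mean_k = mean_{k-1} + (x_k - mean_{k-1}) / k
#   m2_k   = m2_{k-1} + (x_k - mean_{k-1}) * (x_k - mean_k),
# reproduce the two-pass sample statistics: for the stream
# [1.0, 2.0, 4.0, 8.0], mean = 3.75 and std = sqrt(m2 / 3) ≈ 3.096.

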
def build_vocab(
    path: str,
    disc_cols: List[str],
    max_rows: Optional[int] = None,
) -> Dict[str, Dict[str, int]]:
    """Build a token -> index vocabulary for each discrete column.

    "<UNK>" is appended after sorting, so it always maps to the last
    index; e.g. observed values {"off", "on"} yield
    {"off": 0, "on": 1, "<UNK>": 2}.
    """
    values = {c: set() for c in disc_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in disc_cols:
            values[c].add(row[c])
        if max_rows is not None and i + 1 >= max_rows:
            break

    vocab = {}
    for c in disc_cols:
        tokens = sorted(values[c])
        if "<UNK>" not in tokens:
            tokens.append("<UNK>")
        vocab[c] = {tok: idx for idx, tok in enumerate(tokens)}
    return vocab


def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[str, float]):
    """Standardize a (..., len(cont_cols)) tensor with per-column stats."""
    import torch  # local import keeps torch optional for the pure-CSV helpers

    mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
    std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
    return (x - mean_t) / std_t


def windowed_batches(
    path: str,
    cont_cols: List[str],
    disc_cols: List[str],
    vocab: Dict[str, Dict[str, int]],
    mean: Dict[str, float],
    std: Dict[str, float],
    batch_size: int,
    seq_len: int,
    max_batches: Optional[int] = None,
):
    """Yield (x_cont, x_disc) tensors of shape (batch_size, seq_len, n_cols).

    Rows are grouped into non-overlapping windows of length seq_len;
    unknown discrete values fall back to the "<UNK>" index.
    """
    import torch  # local import, as in normalize_cont

    batch_cont = []
    batch_disc = []
    seq_cont = []
    seq_disc = []

    def flush_seq():
        nonlocal seq_cont, seq_disc, batch_cont, batch_disc
        if len(seq_cont) == seq_len:
            batch_cont.append(seq_cont)
            batch_disc.append(seq_disc)
        seq_cont = []
        seq_disc = []

    batches_yielded = 0
    for row in iter_rows(path):
        cont_row = [float(row[c]) for c in cont_cols]
        disc_row = [vocab[c].get(row[c], vocab[c]["<UNK>"]) for c in disc_cols]
        seq_cont.append(cont_row)
        seq_disc.append(disc_row)
        if len(seq_cont) == seq_len:
            flush_seq()
            if len(batch_cont) == batch_size:
                x_cont = torch.tensor(batch_cont, dtype=torch.float32)
                x_disc = torch.tensor(batch_disc, dtype=torch.long)
                x_cont = normalize_cont(x_cont, cont_cols, mean, std)
                yield x_cont, x_disc
                batch_cont = []
                batch_disc = []
                batches_yielded += 1
                if max_batches is not None and batches_yielded >= max_batches:
                    return
    # Drop the last partial window/batch for simplicity.
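

# A minimal end-to-end sketch (not part of the original pipeline): writes a
# tiny synthetic gzip CSV so the helpers above can be exercised without the
# real HAI 21.03 files. Column names here are placeholders, and the batching
# step assumes PyTorch is installed.
if __name__ == "__main__":
    import os
    import statistics
    import tempfile

    fd, tmp_path = tempfile.mkstemp(suffix=".csv.gz")
    os.close(fd)
    with gzip.open(tmp_path, "wt", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["c1", "c2", "d1"])
        writer.writeheader()
        writer.writerows([
            {"c1": "1.0", "c2": "10.0", "d1": "off"},
            {"c1": "2.0", "c2": "30.0", "d1": "on"},
            {"c1": "4.0", "c2": "50.0", "d1": "on"},
            {"c1": "8.0", "c2": "70.0", "d1": "off"},
        ])

    cont_cols, disc_cols = ["c1", "c2"], ["d1"]
    mean, std = compute_cont_stats(tmp_path, cont_cols)
    vocab = build_vocab(tmp_path, disc_cols)

    # Welford stats should match the two-pass sample statistics.
    assert abs(mean["c1"] - statistics.mean([1.0, 2.0, 4.0, 8.0])) < 1e-9
    assert abs(std["c1"] - statistics.stdev([1.0, 2.0, 4.0, 8.0])) < 1e-9

    # 4 rows with seq_len=2 give two windows, i.e. one batch of size 2.
    for x_cont, x_disc in windowed_batches(
        tmp_path, cont_cols, disc_cols, vocab, mean, std,
        batch_size=2, seq_len=2,
    ):
        print(x_cont.shape, x_disc.shape)  # torch.Size([2, 2, 2]) torch.Size([2, 2, 1])

    os.remove(tmp_path)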