mask-ddpm/example/data_utils.py

#!/usr/bin/env python3
"""Small utilities for HAI 21.03 data loading and feature encoding."""
import csv
import gzip
import json
import math
import random
from typing import Dict, Iterable, List, Optional, Tuple, Union


def load_split(path: str) -> Dict[str, List[str]]:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def iter_rows(path_or_paths: Union[str, List[str]]) -> Iterable[Dict[str, str]]:
    paths = [path_or_paths] if isinstance(path_or_paths, str) else list(path_or_paths)
    for path in paths:
        opener = gzip.open if str(path).endswith(".gz") else open
        with opener(path, "rt", newline="") as f:
            reader = csv.DictReader(f)
            for row in reader:
                yield row


def _stream_basic_stats(
    path: Union[str, List[str]],
    cont_cols: List[str],
    max_rows: Optional[int] = None,
):
    """Streaming stats with mean/M2/M3 + min/max + int/precision metadata."""
    count = {c: 0 for c in cont_cols}
    mean = {c: 0.0 for c in cont_cols}
    m2 = {c: 0.0 for c in cont_cols}
    m3 = {c: 0.0 for c in cont_cols}
    vmin = {c: float("inf") for c in cont_cols}
    vmax = {c: float("-inf") for c in cont_cols}
    int_like = {c: True for c in cont_cols}
    max_decimals = {c: 0 for c in cont_cols}
    all_pos = {c: True for c in cont_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in cont_cols:
            raw = row[c]
            if raw is None or raw == "":
                continue
            x = float(raw)
            if x <= 0:
                all_pos[c] = False
            if x < vmin[c]:
                vmin[c] = x
            if x > vmax[c]:
                vmax[c] = x
            if int_like[c] and abs(x - round(x)) > 1e-9:
                int_like[c] = False
            if "e" not in raw and "E" not in raw and "." in raw:
                dec = raw.split(".", 1)[1].rstrip("0")
                if len(dec) > max_decimals[c]:
                    max_decimals[c] = len(dec)
            n = count[c] + 1
            delta = x - mean[c]
            delta_n = delta / n
            term1 = delta * delta_n * (n - 1)
            m3[c] += term1 * delta_n * (n - 2) - 3 * delta_n * m2[c]
            m2[c] += term1
            mean[c] += delta_n
            count[c] = n
        if max_rows is not None and i + 1 >= max_rows:
            break
    # finalize std/skew
    std = {}
    skew = {}
    for c in cont_cols:
        n = count[c]
        if n > 1:
            var = m2[c] / (n - 1)
        else:
            var = 0.0
        std[c] = var ** 0.5 if var > 0 else 1.0
        if n > 2 and m2[c] > 0:
            # Sample skewness from the accumulated central moments: sqrt(n) * M3 / M2**1.5
            skew[c] = (math.sqrt(n) * m3[c]) / (m2[c] ** 1.5)
        else:
            skew[c] = 0.0
    for c in cont_cols:
        if vmin[c] == float("inf"):
            vmin[c] = 0.0
        if vmax[c] == float("-inf"):
            vmax[c] = 0.0
    return {
        "count": count,
        "mean": mean,
        "std": std,
        "m2": m2,
        "m3": m3,
        "min": vmin,
        "max": vmax,
        "int_like": int_like,
        "max_decimals": max_decimals,
        "skew": skew,
        "all_pos": all_pos,
    }


def choose_cont_transforms(
    path: Union[str, List[str]],
    cont_cols: List[str],
    max_rows: Optional[int] = None,
    skew_threshold: float = 1.5,
    range_ratio_threshold: float = 1e3,
):
    """Pick per-column transform (currently log1p or none) based on skew/range."""
    stats = _stream_basic_stats(path, cont_cols, max_rows=max_rows)
    transforms = {}
    for c in cont_cols:
        if not stats["all_pos"][c]:
            transforms[c] = "none"
            continue
        skew = abs(stats["skew"][c])
        vmin = stats["min"][c]
        vmax = stats["max"][c]
        ratio = (vmax / vmin) if vmin > 0 else 0.0
        if skew >= skew_threshold or ratio >= range_ratio_threshold:
            transforms[c] = "log1p"
        else:
            transforms[c] = "none"
    return transforms, stats
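
# Example usage (a sketch, not part of the original module): pick transforms for
# two continuous columns of a training CSV. "train1.csv" and the column names
# are hypothetical placeholders, not values taken from this repository.
#
#     transforms, raw_stats = choose_cont_transforms(
#         "train1.csv", ["P1_FCV01D", "P1_B2004"], max_rows=100_000
#     )
#     # e.g. transforms -> {"P1_FCV01D": "log1p", "P1_B2004": "none"}
#     # raw_stats is the dict returned by _stream_basic_stats and can be reused.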


def compute_cont_stats(
    path: Union[str, List[str]],
    cont_cols: List[str],
    max_rows: Optional[int] = None,
    transforms: Optional[Dict[str, str]] = None,
    quantile_bins: Optional[int] = None,
):
    """Compute stats on (optionally transformed) values. Returns raw + transformed stats."""
    # First pass (raw) for metadata and raw mean/std
    raw = _stream_basic_stats(path, cont_cols, max_rows=max_rows)
    # Optional transform selection
    if transforms is None:
        transforms = {c: "none" for c in cont_cols}
    # Second pass for transformed mean/std (and optional quantiles)
    count = {c: 0 for c in cont_cols}
    mean = {c: 0.0 for c in cont_cols}
    m2 = {c: 0.0 for c in cont_cols}
    quantile_values = {c: [] for c in cont_cols} if quantile_bins and quantile_bins > 1 else None
    for i, row in enumerate(iter_rows(path)):
        for c in cont_cols:
            raw_val = row[c]
            if raw_val is None or raw_val == "":
                continue
            x = float(raw_val)
            if transforms.get(c) == "log1p":
                if x < 0:
                    x = 0.0
                x = math.log1p(x)
            if quantile_values is not None:
                quantile_values[c].append(x)
            n = count[c] + 1
            delta = x - mean[c]
            mean[c] += delta / n
            delta2 = x - mean[c]
            m2[c] += delta * delta2
            count[c] = n
        if max_rows is not None and i + 1 >= max_rows:
            break
    std = {}
    for c in cont_cols:
        if count[c] > 1:
            var = m2[c] / (count[c] - 1)
        else:
            var = 0.0
        std[c] = var ** 0.5 if var > 0 else 1.0
    quantile_probs = None
    quantile_table = None
    if quantile_values is not None:
        quantile_probs = [i / (quantile_bins - 1) for i in range(quantile_bins)]
        quantile_table = {}
        for c in cont_cols:
            vals = quantile_values[c]
            if not vals:
                quantile_table[c] = [0.0 for _ in quantile_probs]
                continue
            vals.sort()
            n = len(vals)
            qvals = []
            for p in quantile_probs:
                idx = int(round(p * (n - 1)))
                idx = max(0, min(n - 1, idx))
                qvals.append(float(vals[idx]))
            quantile_table[c] = qvals
    return {
        "mean": mean,
        "std": std,
        "raw_mean": raw["mean"],
        "raw_std": raw["std"],
        "min": raw["min"],
        "max": raw["max"],
        "int_like": raw["int_like"],
        "max_decimals": raw["max_decimals"],
        "transform": transforms,
        "skew": raw["skew"],
        "all_pos": raw["all_pos"],
        "max_rows": max_rows,
        "quantile_probs": quantile_probs,
        "quantile_values": quantile_table,
    }
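
# Example usage (sketch; "train1.csv" and the column list are assumptions):
# compute stats on log1p-transformed values plus a 1000-point quantile table
# for the optional quantile transform used further below.
#
#     cols = ["P1_FCV01D", "P1_B2004"]
#     transforms, _ = choose_cont_transforms("train1.csv", cols)
#     stats = compute_cont_stats(
#         "train1.csv", cols, transforms=transforms, quantile_bins=1000
#     )
#     # stats["mean"]/stats["std"] describe the transformed values, while
#     # stats["raw_mean"]/stats["raw_std"] describe the untransformed ones.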


def build_vocab(
    path: Union[str, List[str]],
    disc_cols: List[str],
    max_rows: Optional[int] = None,
) -> Dict[str, Dict[str, int]]:
    values = {c: set() for c in disc_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in disc_cols:
            values[c].add(row[c])
        if max_rows is not None and i + 1 >= max_rows:
            break
    vocab = {}
    for c in disc_cols:
        tokens = sorted(values[c])
        if "<UNK>" not in tokens:
            tokens.append("<UNK>")
        vocab[c] = {tok: idx for idx, tok in enumerate(tokens)}
    return vocab


def build_disc_stats(
    path: Union[str, List[str]],
    disc_cols: List[str],
    max_rows: Optional[int] = None,
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, str]]:
    counts = {c: {} for c in disc_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in disc_cols:
            val = row[c]
            counts[c][val] = counts[c].get(val, 0) + 1
        if max_rows is not None and i + 1 >= max_rows:
            break
    vocab = {}
    top_token = {}
    for c in disc_cols:
        tokens = sorted(counts[c].keys())
        if "<UNK>" not in tokens:
            tokens.append("<UNK>")
        vocab[c] = {tok: idx for idx, tok in enumerate(tokens)}
        # most frequent token
        if counts[c]:
            top_token[c] = max(counts[c].items(), key=lambda kv: kv[1])[0]
        else:
            top_token[c] = "<UNK>"
    return vocab, top_token
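
# Example usage (sketch; the column names are placeholders): build_disc_stats
# returns the same token->index mapping as build_vocab plus the most frequent
# token per column, which is a convenient fallback fill value.
#
#     vocab, top_token = build_disc_stats("train1.csv", ["P1_PP01AD", "P1_PP01BD"])
#     unk_id = vocab["P1_PP01AD"]["<UNK>"]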


def normalize_cont(
    x,
    cont_cols: List[str],
    mean: Dict[str, float],
    std: Dict[str, float],
    transforms: Optional[Dict[str, str]] = None,
    quantile_probs: Optional[List[float]] = None,
    quantile_values: Optional[Dict[str, List[float]]] = None,
    use_quantile: bool = False,
):
    import torch
    if transforms:
        for i, c in enumerate(cont_cols):
            if transforms.get(c) == "log1p":
                x[:, :, i] = torch.log1p(torch.clamp(x[:, :, i], min=0))
    if use_quantile:
        if not quantile_probs or not quantile_values:
            raise ValueError("use_quantile_transform enabled but quantile stats missing")
        x = apply_quantile_transform(x, cont_cols, quantile_probs, quantile_values)
        # quantile transform already targets N(0,1); skip extra standardization
        return x
    mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
    std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
    return (x - mean_t) / std_t
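
# Example usage (sketch): z-score a small (batch, seq, feature) tensor with
# precomputed per-column mean/std; the numbers below are made up.
#
#     import torch
#     x = torch.zeros(2, 4, 1)
#     x_norm = normalize_cont(x, ["c0"], mean={"c0": 1.0}, std={"c0": 2.0})
#     # x_norm == (x - 1.0) / 2.0 elementwise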


def _normal_cdf(x):
    import torch
    return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def _normal_ppf(p):
    import torch
    eps = 1e-6
    p = torch.clamp(p, eps, 1.0 - eps)
    return math.sqrt(2.0) * torch.erfinv(2.0 * p - 1.0)


def apply_quantile_transform(x, cont_cols, quantile_probs, quantile_values):
    import torch
    probs_t = torch.tensor(quantile_probs, dtype=x.dtype, device=x.device)
    for i, c in enumerate(cont_cols):
        q_vals = torch.tensor(quantile_values[c], dtype=x.dtype, device=x.device)
        v = x[:, :, i]
        idx = torch.bucketize(v, q_vals)
        idx = torch.clamp(idx, 1, q_vals.numel() - 1)
        x0 = q_vals[idx - 1]
        x1 = q_vals[idx]
        p0 = probs_t[idx - 1]
        p1 = probs_t[idx]
        denom = torch.where((x1 - x0) == 0, torch.ones_like(x1 - x0), (x1 - x0))
        p = p0 + (v - x0) * (p1 - p0) / denom
        x[:, :, i] = _normal_ppf(p)
    return x


def inverse_quantile_transform(x, cont_cols, quantile_probs, quantile_values):
    import torch
    probs_t = torch.tensor(quantile_probs, dtype=x.dtype, device=x.device)
    for i, c in enumerate(cont_cols):
        q_vals = torch.tensor(quantile_values[c], dtype=x.dtype, device=x.device)
        z = x[:, :, i]
        p = _normal_cdf(z)
        idx = torch.bucketize(p, probs_t)
        idx = torch.clamp(idx, 1, probs_t.numel() - 1)
        p0 = probs_t[idx - 1]
        p1 = probs_t[idx]
        x0 = q_vals[idx - 1]
        x1 = q_vals[idx]
        denom = torch.where((p1 - p0) == 0, torch.ones_like(p1 - p0), (p1 - p0))
        v = x0 + (p - p0) * (x1 - x0) / denom
        x[:, :, i] = v
    return x
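
# Example usage (sketch; the quantile table below is synthetic, not from real
# data): apply_quantile_transform maps values to approximately N(0, 1) through
# the empirical quantile table, and inverse_quantile_transform maps them back.
# For values inside the observed range the round trip is close to identity, up
# to the piecewise-linear interpolation and the ppf clamping.
#
#     import torch
#     probs = [i / 9 for i in range(10)]
#     qvals = {"c0": [float(v) for v in range(10)]}
#     x = torch.tensor([[[2.5], [7.0]]])  # shape (batch=1, seq=2, feature=1)
#     z = apply_quantile_transform(x.clone(), ["c0"], probs, qvals)
#     x_back = inverse_quantile_transform(z.clone(), ["c0"], probs, qvals)
#     # x_back is approximately equal to x (both functions modify in place).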


def windowed_batches(
    path: Union[str, List[str]],
    cont_cols: List[str],
    disc_cols: List[str],
    vocab: Dict[str, Dict[str, int]],
    mean: Dict[str, float],
    std: Dict[str, float],
    batch_size: int,
    seq_len: int,
    max_batches: Optional[int] = None,
    return_file_id: bool = False,
    transforms: Optional[Dict[str, str]] = None,
    quantile_probs: Optional[List[float]] = None,
    quantile_values: Optional[Dict[str, List[float]]] = None,
    use_quantile: bool = False,
    shuffle_buffer: int = 0,
):
    import torch
    batch_cont = []
    batch_disc = []
    batch_file = []
    buffer = []
    seq_cont = []
    seq_disc = []

    def flush_seq(file_id: int):
        nonlocal seq_cont, seq_disc, batch_cont, batch_disc, batch_file
        if len(seq_cont) == seq_len:
            if shuffle_buffer and shuffle_buffer > 0:
                buffer.append((list(seq_cont), list(seq_disc), file_id))
                if len(buffer) >= shuffle_buffer:
                    idx = random.randrange(len(buffer))
                    seq_c, seq_d, seq_f = buffer.pop(idx)
                    batch_cont.append(seq_c)
                    batch_disc.append(seq_d)
                    if return_file_id:
                        batch_file.append(seq_f)
            else:
                batch_cont.append(seq_cont)
                batch_disc.append(seq_disc)
                if return_file_id:
                    batch_file.append(file_id)
            seq_cont = []
            seq_disc = []

    batches_yielded = 0
    paths = [path] if isinstance(path, str) else list(path)
    for file_id, p in enumerate(paths):
        for row in iter_rows(p):
            cont_row = [float(row[c]) for c in cont_cols]
            disc_row = [vocab[c].get(row[c], vocab[c]["<UNK>"]) for c in disc_cols]
            seq_cont.append(cont_row)
            seq_disc.append(disc_row)
            if len(seq_cont) == seq_len:
                flush_seq(file_id)
                if len(batch_cont) == batch_size:
                    x_cont = torch.tensor(batch_cont, dtype=torch.float32)
                    x_disc = torch.tensor(batch_disc, dtype=torch.long)
                    x_cont = normalize_cont(
                        x_cont,
                        cont_cols,
                        mean,
                        std,
                        transforms=transforms,
                        quantile_probs=quantile_probs,
                        quantile_values=quantile_values,
                        use_quantile=use_quantile,
                    )
                    if return_file_id:
                        x_file = torch.tensor(batch_file, dtype=torch.long)
                        yield x_cont, x_disc, x_file
                    else:
                        yield x_cont, x_disc
                    batch_cont = []
                    batch_disc = []
                    batch_file = []
                    batches_yielded += 1
                    if max_batches is not None and batches_yielded >= max_batches:
                        return
        # drop partial sequence at file boundary
        seq_cont = []
        seq_disc = []
    # Flush any remaining buffered sequences
    if shuffle_buffer and buffer:
        random.shuffle(buffer)
        for seq_c, seq_d, seq_f in buffer:
            batch_cont.append(seq_c)
            batch_disc.append(seq_d)
            if return_file_id:
                batch_file.append(seq_f)
            if len(batch_cont) == batch_size:
                x_cont = torch.tensor(batch_cont, dtype=torch.float32)
                x_disc = torch.tensor(batch_disc, dtype=torch.long)
                x_cont = normalize_cont(
                    x_cont,
                    cont_cols,
                    mean,
                    std,
                    transforms=transforms,
                    quantile_probs=quantile_probs,
                    quantile_values=quantile_values,
                    use_quantile=use_quantile,
                )
                if return_file_id:
                    x_file = torch.tensor(batch_file, dtype=torch.long)
                    yield x_cont, x_disc, x_file
                else:
                    yield x_cont, x_disc
                batch_cont = []
                batch_disc = []
                batch_file = []
                batches_yielded += 1
                if max_batches is not None and batches_yielded >= max_batches:
                    return
    # Drop last partial batch for simplicity
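

if __name__ == "__main__":
    # Minimal smoke test (an added sketch, not part of the original pipeline).
    # Usage: python data_utils.py <train.csv>. Column typing is inferred from
    # the first data row: values that parse as float are treated as continuous,
    # everything else as discrete. A real config would list columns explicitly.
    import sys

    csv_path = sys.argv[1]
    first_row = next(iter(iter_rows(csv_path)))
    cont_cols, disc_cols = [], []
    for col, val in first_row.items():
        try:
            float(val)
            cont_cols.append(col)
        except (TypeError, ValueError):
            disc_cols.append(col)

    transforms, _ = choose_cont_transforms(csv_path, cont_cols, max_rows=10_000)
    stats = compute_cont_stats(csv_path, cont_cols, max_rows=10_000, transforms=transforms)
    vocab, _ = build_disc_stats(csv_path, disc_cols, max_rows=10_000)

    for x_cont, x_disc in windowed_batches(
        csv_path,
        cont_cols,
        disc_cols,
        vocab,
        stats["mean"],
        stats["std"],
        batch_size=4,
        seq_len=32,
        max_batches=2,
        transforms=transforms,
    ):
        print("cont", tuple(x_cont.shape), "disc", tuple(x_disc.shape))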