#!/usr/bin/env python3
"""Small utilities for HAI 21.03 data loading and feature encoding."""
import csv
import gzip
import json
import math
import random
from typing import Dict, Iterable, List, Optional, Tuple, Union


def load_split(path: str) -> Dict[str, List[str]]:
    """Load a JSON split file (e.g. mapping split names to lists of file paths)."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def iter_rows(path_or_paths: Union[str, List[str]]) -> Iterable[Dict[str, str]]:
    """Stream CSV rows as dicts from one or more files; ".gz" files are
    decompressed transparently."""
    paths = [path_or_paths] if isinstance(path_or_paths, str) else list(path_or_paths)
    for path in paths:
        opener = gzip.open if str(path).endswith(".gz") else open
        with opener(path, "rt", newline="") as f:
            reader = csv.DictReader(f)
            for row in reader:
                yield row
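
# Usage sketch (paths here are hypothetical):
#   for row in iter_rows(["hai/train1.csv", "hai/train2.csv.gz"]):
#       ...  # each row is a dict of column name -> raw string value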


def _stream_basic_stats(
    path: Union[str, List[str]],
    cont_cols: List[str],
    max_rows: Optional[int] = None,
):
    """Streaming stats with mean/M2/M3 + min/max + int/precision metadata."""
    count = {c: 0 for c in cont_cols}
    mean = {c: 0.0 for c in cont_cols}
    m2 = {c: 0.0 for c in cont_cols}  # running sum of squared deviations
    m3 = {c: 0.0 for c in cont_cols}  # running sum of cubed deviations
    vmin = {c: float("inf") for c in cont_cols}
    vmax = {c: float("-inf") for c in cont_cols}
    int_like = {c: True for c in cont_cols}
    max_decimals = {c: 0 for c in cont_cols}
    all_pos = {c: True for c in cont_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in cont_cols:
            raw = row[c]
            if raw is None or raw == "":
                continue
            x = float(raw)
            if x <= 0:
                all_pos[c] = False
            if x < vmin[c]:
                vmin[c] = x
            if x > vmax[c]:
                vmax[c] = x
            if int_like[c] and abs(x - round(x)) > 1e-9:
                int_like[c] = False
            if "e" not in raw and "E" not in raw and "." in raw:
                dec = raw.split(".", 1)[1].rstrip("0")
                if len(dec) > max_decimals[c]:
                    max_decimals[c] = len(dec)
            # One-pass Welford-style update of mean/M2/M3.
            n = count[c] + 1
            delta = x - mean[c]
            delta_n = delta / n
            term1 = delta * delta_n * (n - 1)
            m3[c] += term1 * delta_n * (n - 2) - 3 * delta_n * m2[c]
            m2[c] += term1
            mean[c] += delta_n
            count[c] = n
        if max_rows is not None and i + 1 >= max_rows:
            break
    # finalize std/skew
    std = {}
    skew = {}
    for c in cont_cols:
        n = count[c]
        if n > 1:
            var = m2[c] / (n - 1)
        else:
            var = 0.0
        std[c] = var ** 0.5 if var > 0 else 1.0
        if n > 2 and m2[c] > 0:
            # Sample skewness g1 = sqrt(n) * M3 / M2**1.5.
            skew[c] = math.sqrt(n) * m3[c] / (m2[c] ** 1.5)
        else:
            skew[c] = 0.0
    for c in cont_cols:
        if vmin[c] == float("inf"):
            vmin[c] = 0.0
        if vmax[c] == float("-inf"):
            vmax[c] = 0.0
    return {
        "count": count,
        "mean": mean,
        "std": std,
        "m2": m2,
        "m3": m3,
        "min": vmin,
        "max": vmax,
        "int_like": int_like,
        "max_decimals": max_decimals,
        "skew": skew,
        "all_pos": all_pos,
    }


def choose_cont_transforms(
    path: Union[str, List[str]],
    cont_cols: List[str],
    max_rows: Optional[int] = None,
    skew_threshold: float = 1.5,
    range_ratio_threshold: float = 1e3,
):
    """Pick per-column transform (currently log1p or none) based on skew/range."""
    stats = _stream_basic_stats(path, cont_cols, max_rows=max_rows)
    transforms = {}
    for c in cont_cols:
        if not stats["all_pos"][c]:
            transforms[c] = "none"
            continue
        skew = abs(stats["skew"][c])
        vmin = stats["min"][c]
        vmax = stats["max"][c]
        ratio = (vmax / vmin) if vmin > 0 else 0.0
        if skew >= skew_threshold or ratio >= range_ratio_threshold:
            transforms[c] = "log1p"
        else:
            transforms[c] = "none"
    return transforms, stats
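
# Typical two-pass usage (sketch): choose transforms on the training files,
# then compute mean/std in the transformed space so normalization matches:
#   transforms, _ = choose_cont_transforms(train_paths, cont_cols)
#   stats = compute_cont_stats(train_paths, cont_cols, transforms=transforms)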


def compute_cont_stats(
    path: Union[str, List[str]],
    cont_cols: List[str],
    max_rows: Optional[int] = None,
    transforms: Optional[Dict[str, str]] = None,
):
    """Compute stats on (optionally transformed) values. Returns raw + transformed stats."""
    # First pass (raw) for metadata and raw mean/std
    raw = _stream_basic_stats(path, cont_cols, max_rows=max_rows)
    # Default to identity transforms when none were chosen
    if transforms is None:
        transforms = {c: "none" for c in cont_cols}
    # Second pass for transformed mean/std (Welford update, mean/M2 only)
    count = {c: 0 for c in cont_cols}
    mean = {c: 0.0 for c in cont_cols}
    m2 = {c: 0.0 for c in cont_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in cont_cols:
            raw_val = row[c]
            if raw_val is None or raw_val == "":
                continue
            x = float(raw_val)
            if transforms.get(c) == "log1p":
                if x < 0:
                    x = 0.0
                x = math.log1p(x)
            n = count[c] + 1
            delta = x - mean[c]
            mean[c] += delta / n
            delta2 = x - mean[c]
            m2[c] += delta * delta2
            count[c] = n
        if max_rows is not None and i + 1 >= max_rows:
            break
    std = {}
    for c in cont_cols:
        if count[c] > 1:
            var = m2[c] / (count[c] - 1)
        else:
            var = 0.0
        std[c] = var ** 0.5 if var > 0 else 1.0
    return {
        "mean": mean,
        "std": std,
        "raw_mean": raw["mean"],
        "raw_std": raw["std"],
        "min": raw["min"],
        "max": raw["max"],
        "int_like": raw["int_like"],
        "max_decimals": raw["max_decimals"],
        "transform": transforms,
        "skew": raw["skew"],
        "all_pos": raw["all_pos"],
        "max_rows": max_rows,
    }


def build_vocab(
    path: Union[str, List[str]],
    disc_cols: List[str],
    max_rows: Optional[int] = None,
) -> Dict[str, Dict[str, int]]:
    """Build a token -> index vocab per discrete column, always including <UNK>."""
    values = {c: set() for c in disc_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in disc_cols:
            values[c].add(row[c])
        if max_rows is not None and i + 1 >= max_rows:
            break
    vocab = {}
    for c in disc_cols:
        tokens = sorted(values[c])
        if "<UNK>" not in tokens:
            tokens.append("<UNK>")
        vocab[c] = {tok: idx for idx, tok in enumerate(tokens)}
    return vocab


def build_disc_stats(
    path: Union[str, List[str]],
    disc_cols: List[str],
    max_rows: Optional[int] = None,
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, str]]:
    """Like build_vocab, but also returns the most frequent token per column."""
    counts = {c: {} for c in disc_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in disc_cols:
            val = row[c]
            counts[c][val] = counts[c].get(val, 0) + 1
        if max_rows is not None and i + 1 >= max_rows:
            break
    vocab = {}
    top_token = {}
    for c in disc_cols:
        tokens = sorted(counts[c].keys())
        if "<UNK>" not in tokens:
            tokens.append("<UNK>")
        vocab[c] = {tok: idx for idx, tok in enumerate(tokens)}
        # most frequent token
        if counts[c]:
            top_token[c] = max(counts[c].items(), key=lambda kv: kv[1])[0]
        else:
            top_token[c] = "<UNK>"
    return vocab, top_token


def normalize_cont(
    x,
    cont_cols: List[str],
    mean: Dict[str, float],
    std: Dict[str, float],
    transforms: Optional[Dict[str, str]] = None,
):
    """Apply per-column transforms, then standardize to zero mean / unit std."""
    import torch  # deferred so the module imports without torch installed

    if transforms:
        for i, c in enumerate(cont_cols):
            if transforms.get(c) == "log1p":
                x[:, :, i] = torch.log1p(torch.clamp(x[:, :, i], min=0))
    mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
    std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
    return (x - mean_t) / std_t
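
# Note: normalize_cont expects a (batch, seq_len, len(cont_cols)) tensor and
# applies the log1p branch in place before standardizing, so callers should
# pass a tensor they own (windowed_batches builds a fresh one per batch).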


def windowed_batches(
    path: Union[str, List[str]],
    cont_cols: List[str],
    disc_cols: List[str],
    vocab: Dict[str, Dict[str, int]],
    mean: Dict[str, float],
    std: Dict[str, float],
    batch_size: int,
    seq_len: int,
    max_batches: Optional[int] = None,
    return_file_id: bool = False,
    transforms: Optional[Dict[str, str]] = None,
    shuffle_buffer: int = 0,
):
    """Yield (x_cont, x_disc[, x_file]) batches of non-overlapping windows.

    Windows never cross file boundaries. With shuffle_buffer > 0, completed
    windows pass through a shuffle buffer to decorrelate adjacent batches.
    """
    import torch  # deferred so the module imports without torch installed

    batch_cont = []
    batch_disc = []
    batch_file = []
    buffer = []
    seq_cont = []
    seq_disc = []

    def flush_seq(file_id: int):
        nonlocal seq_cont, seq_disc, batch_cont, batch_disc, batch_file
        if len(seq_cont) == seq_len:
            if shuffle_buffer and shuffle_buffer > 0:
                buffer.append((list(seq_cont), list(seq_disc), file_id))
                if len(buffer) >= shuffle_buffer:
                    # Buffer full: emit one window chosen uniformly at random.
                    idx = random.randrange(len(buffer))
                    seq_c, seq_d, seq_f = buffer.pop(idx)
                    batch_cont.append(seq_c)
                    batch_disc.append(seq_d)
                    if return_file_id:
                        batch_file.append(seq_f)
            else:
                batch_cont.append(seq_cont)
                batch_disc.append(seq_disc)
                if return_file_id:
                    batch_file.append(file_id)
            seq_cont = []
            seq_disc = []

    batches_yielded = 0
    paths = [path] if isinstance(path, str) else list(path)
    for file_id, p in enumerate(paths):
        for row in iter_rows(p):
            # NOTE: unlike the stats passes, this assumes cont_cols have no
            # blank values; float("") would raise here.
            cont_row = [float(row[c]) for c in cont_cols]
            disc_row = [vocab[c].get(row[c], vocab[c]["<UNK>"]) for c in disc_cols]
            seq_cont.append(cont_row)
            seq_disc.append(disc_row)
            if len(seq_cont) == seq_len:
                flush_seq(file_id)
                if len(batch_cont) == batch_size:
                    x_cont = torch.tensor(batch_cont, dtype=torch.float32)
                    x_disc = torch.tensor(batch_disc, dtype=torch.long)
                    x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
                    if return_file_id:
                        x_file = torch.tensor(batch_file, dtype=torch.long)
                        yield x_cont, x_disc, x_file
                    else:
                        yield x_cont, x_disc
                    batch_cont = []
                    batch_disc = []
                    batch_file = []
                    batches_yielded += 1
                    if max_batches is not None and batches_yielded >= max_batches:
                        return
        # drop partial sequence at file boundary
        seq_cont = []
        seq_disc = []
    # Flush any remaining buffered sequences
    if shuffle_buffer and buffer:
        random.shuffle(buffer)
        for seq_c, seq_d, seq_f in buffer:
            batch_cont.append(seq_c)
            batch_disc.append(seq_d)
            if return_file_id:
                batch_file.append(seq_f)
            if len(batch_cont) == batch_size:
                x_cont = torch.tensor(batch_cont, dtype=torch.float32)
                x_disc = torch.tensor(batch_disc, dtype=torch.long)
                x_cont = normalize_cont(x_cont, cont_cols, mean, std, transforms=transforms)
                if return_file_id:
                    x_file = torch.tensor(batch_file, dtype=torch.long)
                    yield x_cont, x_disc, x_file
                else:
                    yield x_cont, x_disc
                batch_cont = []
                batch_disc = []
                batch_file = []
                batches_yielded += 1
                if max_batches is not None and batches_yielded >= max_batches:
                    return
    # Drop last partial batch for simplicity
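

if __name__ == "__main__":
    # Minimal smoke test on synthetic data -- a sketch, not part of the real
    # pipeline. Column names ("sensor_a", "valve_state") are made up for
    # illustration; real usage points these helpers at HAI 21.03 CSV files.
    import os
    import tempfile

    fieldnames = ["sensor_a", "valve_state"]
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", delete=False, newline=""
    ) as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for i in range(64):
            writer.writerow({
                "sensor_a": f"{random.uniform(0.0, 100.0):.2f}",
                "valve_state": str(i % 2),
            })
        tmp_path = f.name
    try:
        cont_cols = ["sensor_a"]
        disc_cols = ["valve_state"]
        transforms, _ = choose_cont_transforms(tmp_path, cont_cols)
        stats = compute_cont_stats(tmp_path, cont_cols, transforms=transforms)
        vocab, top_token = build_disc_stats(tmp_path, disc_cols)
        print("transforms:", transforms)
        print("mean/std:", stats["mean"], stats["std"])
        print("top tokens:", top_token)
        try:
            # 64 rows with seq_len=16 -> 4 windows -> two batches of 2.
            first = next(windowed_batches(
                tmp_path, cont_cols, disc_cols, vocab,
                stats["mean"], stats["std"], batch_size=2, seq_len=16,
            ))
            print("batch shapes:", tuple(first[0].shape), tuple(first[1].shape))
        except ImportError:
            print("torch not installed; skipped windowed_batches")
    finally:
        os.unlink(tmp_path)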