Files
mask-ddpm/example/data_utils.py
2026-01-22 21:17:11 +08:00

199 lines
6.5 KiB
Python
Executable File

#!/usr/bin/env python3
"""Small utilities for HAI 21.03 data loading and feature encoding."""
import csv
import gzip
import json
from typing import Dict, Iterable, List, Optional, Tuple, Union
def load_split(path: str) -> Dict[str, List[str]]:
    """Read a JSON split description from ``path`` and return it decoded.

    The file is decoded as UTF-8. The result is returned exactly as
    ``json.load`` produces it; the annotation suggests a mapping of split name
    to a list of strings, but the structure is not validated here.
    """
    with open(path, encoding="utf-8") as handle:
        data = json.load(handle)
    return data
def iter_rows(path_or_paths: Union[str, List[str]]) -> Iterable[Dict[str, str]]:
    """Stream CSV rows as dicts from one file or a sequence of files.

    Files are read in the given order. Each file's own header row supplies the
    keys for its rows. Paths ending in ``.gz`` are decompressed on the fly
    with :mod:`gzip`.

    Args:
        path_or_paths: A single path (``str`` or any ``os.PathLike`` such as
            ``pathlib.Path``) or an iterable of such paths.

    Yields:
        One ``dict`` per data row, as produced by ``csv.DictReader``.
    """
    import os

    # A str/path-like object is one file; anything else is an iterable of files.
    # (Previously a pathlib.Path was passed to list() and raised TypeError.)
    if isinstance(path_or_paths, (str, os.PathLike)):
        paths = [path_or_paths]
    else:
        paths = list(path_or_paths)
    for path in paths:
        opener = gzip.open if str(path).endswith(".gz") else open
        # Decode explicitly instead of relying on the platform locale.
        # "utf-8-sig" also strips an optional BOM, which would otherwise stick
        # to the first header name. newline="" lets csv handle line endings.
        with opener(path, "rt", encoding="utf-8-sig", newline="") as f:
            yield from csv.DictReader(f)
def compute_cont_stats(
    path: Union[str, List[str]],
    cont_cols: List[str],
    max_rows: Optional[int] = None,
) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, bool], Dict[str, int]]:
    """Streaming mean/std (Welford) + min/max + int/precision metadata.

    Makes a single pass over the CSV file(s). For every column in
    ``cont_cols`` it tracks:

    * mean and variance, using Welford's online algorithm;
    * the running min and max;
    * whether every value is integral;
    * the largest number of significant decimal places seen in the raw text.

    Empty or whitespace-only cells, and fields missing from short rows, are
    skipped. Each column's statistics use only that column's own non-empty
    values, so gaps in one column do not bias its mean/std.

    Args:
        path: A CSV path or list of paths (``.gz`` allowed), read via
            ``iter_rows``.
        cont_cols: Names of the continuous columns to summarize.
        max_rows: If given, only the first ``max_rows`` rows are read. A value
            of ``0`` or less reads no rows.

    Returns:
        ``(mean, std, vmin, vmax, int_like, max_decimals)``, each a dict keyed
        by column name:

        * ``std`` is the sample standard deviation (``n - 1`` denominator).
          It falls back to ``1.0`` when it would be zero or is undefined
          (fewer than two values), so dividing by it is always safe.
        * ``mean``, ``vmin`` and ``vmax`` are ``0.0`` for a column with no
          values.
        * ``int_like`` is True when every value is within ``1e-9`` of an
          integer.
        * ``max_decimals`` counts digits after the decimal point, ignoring
          trailing zeros. Values written in scientific notation do not count.

    Raises:
        KeyError: If a column in ``cont_cols`` is absent from a file's header.
        ValueError: If a non-empty cell cannot be parsed with ``float``.
    """
    from itertools import islice

    n_seen = {c: 0 for c in cont_cols}
    mean = {c: 0.0 for c in cont_cols}
    m2 = {c: 0.0 for c in cont_cols}
    vmin = {c: float("inf") for c in cont_cols}
    vmax = {c: float("-inf") for c in cont_cols}
    int_like = {c: True for c in cont_cols}
    max_decimals = {c: 0 for c in cont_cols}

    # islice(it, 0) yields nothing, unlike the old "break after the row" check,
    # which always consumed at least one row.
    limit = None if max_rows is None else max(max_rows, 0)
    for row in islice(iter_rows(path), limit):
        for c in cont_cols:
            raw = row[c]
            if raw is None:  # short row: csv.DictReader fills missing fields with None
                continue
            raw = raw.strip()
            if not raw:
                continue
            x = float(raw)
            # Welford update. Divide by this column's own count, not the row
            # count, so skipped (empty) cells do not drag the mean toward 0.
            n_seen[c] += 1
            delta = x - mean[c]
            mean[c] += delta / n_seen[c]
            m2[c] += delta * (x - mean[c])
            if x < vmin[c]:
                vmin[c] = x
            if x > vmax[c]:
                vmax[c] = x
            if int_like[c] and abs(x - round(x)) > 1e-9:
                int_like[c] = False
            # Precision is inferred from the raw text. Scientific notation
            # cannot be read this way, so it is skipped.
            if "e" in raw or "E" in raw or "." not in raw:
                continue
            decimals = len(raw.split(".", 1)[1].rstrip("0"))
            if decimals > max_decimals[c]:
                max_decimals[c] = decimals

    std = {}
    for c in cont_cols:
        var = m2[c] / (n_seen[c] - 1) if n_seen[c] > 1 else 0.0
        std[c] = var ** 0.5 if var > 0 else 1.0
        # A column that never had a value still holds its +/-inf sentinels.
        if n_seen[c] == 0:
            vmin[c] = 0.0
            vmax[c] = 0.0
    return mean, std, vmin, vmax, int_like, max_decimals
def build_vocab(
    path: Union[str, List[str]],
    disc_cols: List[str],
    max_rows: Optional[int] = None,
) -> Dict[str, Dict[str, int]]:
    """Map each discrete column's observed values to contiguous integer ids.

    Ids are assigned in sorted (string) order of the distinct observed values.
    A ``"<UNK>"`` token is appended last unless it already occurs in the data.

    Fields missing from short rows (``None`` from ``csv.DictReader``) are not
    added to the vocabulary. Previously they made ``sorted`` raise
    ``TypeError``. Such fields fall back to the ``"<UNK>"`` id at lookup time.

    Args:
        path: A CSV path or list of paths, read via ``iter_rows``.
        disc_cols: Names of the discrete columns.
        max_rows: If given, only the first ``max_rows`` rows are scanned. A
            value of ``0`` or less scans no rows.

    Returns:
        ``{column: {token: id}}``.

    Raises:
        KeyError: If a column in ``disc_cols`` is absent from a file's header.
    """
    from itertools import islice

    seen = {c: set() for c in disc_cols}
    limit = None if max_rows is None else max(max_rows, 0)
    for row in islice(iter_rows(path), limit):
        for c in disc_cols:
            val = row[c]
            if val is not None:
                seen[c].add(val)

    vocab = {}
    for c in disc_cols:
        tokens = sorted(seen[c])
        if "<UNK>" not in seen[c]:
            tokens.append("<UNK>")
        vocab[c] = {tok: idx for idx, tok in enumerate(tokens)}
    return vocab
def build_disc_stats(
    path: Union[str, List[str]],
    disc_cols: List[str],
    max_rows: Optional[int] = None,
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, str]]:
    """Build per-column vocabularies plus each column's most frequent token.

    The vocabulary is built the same way as in ``build_vocab``:

    * ids follow the sorted (string) order of the observed values;
    * ``"<UNK>"`` is appended last unless it already occurs in the data.

    Fields missing from short rows (``None`` from ``csv.DictReader``) are
    ignored rather than crashing ``sorted``.

    Args:
        path: A CSV path or list of paths, read via ``iter_rows``.
        disc_cols: Names of the discrete columns.
        max_rows: If given, only the first ``max_rows`` rows are scanned. A
            value of ``0`` or less scans no rows.

    Returns:
        ``(vocab, top_token)``:

        * ``vocab`` is ``{column: {token: id}}``.
        * ``top_token`` maps each column to its most frequent value. Ties go
          to the value encountered first. A column with no values gets
          ``"<UNK>"``.

    Raises:
        KeyError: If a column in ``disc_cols`` is absent from a file's header.
    """
    from collections import Counter
    from itertools import islice

    counts = {c: Counter() for c in disc_cols}
    limit = None if max_rows is None else max(max_rows, 0)
    for row in islice(iter_rows(path), limit):
        for c in disc_cols:
            val = row[c]
            if val is not None:
                counts[c][val] += 1

    vocab = {}
    top_token = {}
    for c in disc_cols:
        tokens = sorted(counts[c])
        if "<UNK>" not in counts[c]:
            tokens.append("<UNK>")
        vocab[c] = {tok: idx for idx, tok in enumerate(tokens)}
        # most_common orders equal counts by first occurrence, matching the
        # original max()-over-insertion-order tie-break.
        top_token[c] = counts[c].most_common(1)[0][0] if counts[c] else "<UNK>"
    return vocab, top_token
def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[str, float]):
    """Standardize continuous features column-wise as ``(x - mean) / std``.

    The statistics are gathered in ``cont_cols`` order and broadcast against
    the last dimension of the tensor ``x``. The statistic tensors are created
    with ``x``'s dtype and device, so the result shares them too.
    """
    center = x.new_tensor([mean[name] for name in cont_cols])
    scale = x.new_tensor([std[name] for name in cont_cols])
    shifted = x - center
    return shifted / scale
def windowed_batches(
    path: Union[str, List[str]],
    cont_cols: List[str],
    disc_cols: List[str],
    vocab: Dict[str, Dict[str, int]],
    mean: Dict[str, float],
    std: Dict[str, float],
    batch_size: int,
    seq_len: int,
    max_batches: Optional[int] = None,
    return_file_id: bool = False,
):
    """Yield batches of fixed-length, non-overlapping row windows from CSV file(s).

    Each file is read in order and cut into consecutive, non-overlapping
    windows of ``seq_len`` rows. A trailing partial window at the end of a
    file is dropped, so a window never spans two files. Windows are grouped
    into batches of ``batch_size``. A final partial batch is also dropped.

    Continuous columns are parsed with ``float`` and standardized with
    ``normalize_cont(..., mean, std)``. Discrete columns are mapped through
    ``vocab``; values not in a column's vocabulary get its ``"<UNK>"`` id.

    Args:
        path: A CSV path or list of paths, each read via ``iter_rows``.
        cont_cols: Continuous column names, in output feature order.
        disc_cols: Discrete column names, in output feature order.
        vocab: ``{column: {token: id}}``. Every entry must contain ``"<UNK>"``.
        mean: Per-column means used for normalization.
        std: Per-column standard deviations used for normalization.
        batch_size: Number of windows per batch (must be >= 1).
        seq_len: Number of rows per window (must be >= 1).
        max_batches: Stop after this many batches. ``0`` or less yields
            nothing.
        return_file_id: Also yield, per window, the index of its source file
            in the path list.

    Yields:
        ``(x_cont, x_disc)`` or ``(x_cont, x_disc, x_file)``:

        * ``x_cont``: float32 tensor of shape
          ``(batch_size, seq_len, len(cont_cols))``.
        * ``x_disc``: long tensor of shape
          ``(batch_size, seq_len, len(disc_cols))``.
        * ``x_file``: long tensor of shape ``(batch_size,)``.

    Raises:
        ValueError: If ``batch_size`` or ``seq_len`` is < 1, or a continuous
            cell cannot be parsed with ``float``. Raised on first iteration.
        KeyError: If a column is missing from ``vocab`` (or ``"<UNK>"`` from a
            column's vocabulary) or from a file's header.
    """
    import torch

    # With seq_len/batch_size < 1 the windows/batches never fill, so the old
    # code silently buffered an entire file and yielded nothing.
    if batch_size < 1 or seq_len < 1:
        raise ValueError(
            f"batch_size and seq_len must be >= 1 "
            f"(got batch_size={batch_size}, seq_len={seq_len})"
        )
    # Previously the limit was only checked after a yield, so 0 still produced one batch.
    if max_batches is not None and max_batches <= 0:
        return

    # Resolve each discrete column's vocabulary and "<UNK>" id once, not per row.
    disc_lookup = [(c, vocab[c], vocab[c]["<UNK>"]) for c in disc_cols]
    paths = [path] if isinstance(path, str) else list(path)

    batch_cont: List[List[List[float]]] = []
    batch_disc: List[List[List[int]]] = []
    batch_file: List[int] = []
    n_batches = 0
    for file_id, p in enumerate(paths):
        # Start a fresh window per file: any partial window left over from the
        # previous file is discarded here.
        win_cont: List[List[float]] = []
        win_disc: List[List[int]] = []
        for row in iter_rows(p):
            win_cont.append([float(row[c]) for c in cont_cols])
            win_disc.append([ids.get(row[c], unk) for c, ids, unk in disc_lookup])
            if len(win_cont) < seq_len:
                continue
            batch_cont.append(win_cont)
            batch_disc.append(win_disc)
            if return_file_id:
                batch_file.append(file_id)
            win_cont, win_disc = [], []
            if len(batch_cont) < batch_size:
                continue
            x_cont = normalize_cont(
                torch.tensor(batch_cont, dtype=torch.float32), cont_cols, mean, std
            )
            x_disc = torch.tensor(batch_disc, dtype=torch.long)
            if return_file_id:
                yield x_cont, x_disc, torch.tensor(batch_file, dtype=torch.long)
            else:
                yield x_cont, x_disc
            batch_cont, batch_disc, batch_file = [], [], []
            n_batches += 1
            if max_batches is not None and n_batches >= max_batches:
                return
    # Drop last partial batch for simplicity