#!/usr/bin/env python3
"""Prepare vocab and normalization stats for HAI 21.03."""
import json
from typing import Optional

from data_utils import compute_cont_stats, build_vocab, load_split

DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
OUT_STATS = "/home/anay/Dev/diffusion/mask-ddpm/example/results/cont_stats.json"
OUT_VOCAB = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"


def main(max_rows: Optional[int] = None):
    # Feature split: lists of continuous and discrete column names.
    split = load_split(SPLIT_PATH)
    cont_cols = split["continuous"]
    disc_cols = split["discrete"]

    # Per-column normalization stats and the discrete-value vocabulary.
    mean, std = compute_cont_stats(DATA_PATH, cont_cols, max_rows=max_rows)
    vocab = build_vocab(DATA_PATH, disc_cols, max_rows=max_rows)

    with open(OUT_STATS, "w", encoding="ascii") as f:
        json.dump({"mean": mean, "std": std, "max_rows": max_rows}, f, indent=2)
    with open(OUT_VOCAB, "w", encoding="ascii") as f:
        json.dump({"vocab": vocab, "max_rows": max_rows}, f, indent=2)


if __name__ == "__main__":
    # Default: sample 50000 rows for speed. Set to None for a full scan.
    main(max_rows=50000)