Update example and notes

This commit is contained in:
2026-01-09 02:14:20 +08:00
parent 200bdf6136
commit c0639386be
18 changed files with 31656 additions and 0 deletions

32
example/prepare_data.py Executable file
View File

@@ -0,0 +1,32 @@
#!/usr/bin/env python3
"""Prepare vocab and normalization stats for HAI 21.03."""
import json
from typing import Optional
from data_utils import compute_cont_stats, build_vocab, load_split
DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
OUT_STATS = "/home/anay/Dev/diffusion/mask-ddpm/example/results/cont_stats.json"
OUT_VOCAB = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"
def main(max_rows: Optional[int] = None):
split = load_split(SPLIT_PATH)
cont_cols = split["continuous"]
disc_cols = split["discrete"]
mean, std = compute_cont_stats(DATA_PATH, cont_cols, max_rows=max_rows)
vocab = build_vocab(DATA_PATH, disc_cols, max_rows=max_rows)
with open(OUT_STATS, "w", encoding="ascii") as f:
json.dump({"mean": mean, "std": std, "max_rows": max_rows}, f, indent=2)
with open(OUT_VOCAB, "w", encoding="ascii") as f:
json.dump({"vocab": vocab, "max_rows": max_rows}, f, indent=2)
if __name__ == "__main__":
# Default: sample 50000 rows for speed. Set to None for full scan.
main(max_rows=50000)