Clean artifacts and update example pipeline

This commit is contained in:
2026-01-22 16:32:51 +08:00
parent c0639386be
commit c3f750cd9d
20 changed files with 651 additions and 30826 deletions

View File

@@ -2,20 +2,24 @@
"""Prepare vocab and normalization stats for HAI 21.03."""
import json
from pathlib import Path
from typing import Optional
from data_utils import compute_cont_stats, build_vocab, load_split
DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
OUT_STATS = "/home/anay/Dev/diffusion/mask-ddpm/example/results/cont_stats.json"
OUT_VOCAB = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"
BASE_DIR = Path(__file__).resolve().parent
REPO_DIR = BASE_DIR.parent.parent
DATA_PATH = str(REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz")
SPLIT_PATH = str(BASE_DIR / "feature_split.json")
OUT_STATS = str(BASE_DIR / "results" / "cont_stats.json")
OUT_VOCAB = str(BASE_DIR / "results" / "disc_vocab.json")
def main(max_rows: Optional[int] = None):
split = load_split(SPLIT_PATH)
cont_cols = split["continuous"]
disc_cols = split["discrete"]
time_col = split.get("time_column", "time")
cont_cols = [c for c in split["continuous"] if c != time_col]
disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
mean, std = compute_cont_stats(DATA_PATH, cont_cols, max_rows=max_rows)
vocab = build_vocab(DATA_PATH, disc_cols, max_rows=max_rows)