#!/usr/bin/env python3
"""Prepare vocab and normalization stats for HAI 21.03."""
import argparse
import json
from pathlib import Path
from typing import Optional
from data_utils import compute_cont_stats, build_disc_stats, load_split
from platform_utils import safe_path, ensure_dir, resolve_path
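
# Defaults: the train CSVs live under the repo's dataset tree; the feature
# split and the JSON outputs sit next to this script (data_utils and
# platform_utils are presumably sibling helper modules).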
BASE_DIR = Path(__file__).resolve().parent
REPO_DIR = BASE_DIR.parent.parent
DATA_GLOB = REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"
SPLIT_PATH = BASE_DIR / "feature_split.json"
OUT_STATS = BASE_DIR / "results" / "cont_stats.json"
OUT_VOCAB = BASE_DIR / "results" / "disc_vocab.json"
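# Inferred from the accesses in main below: feature_split.json maps
# "continuous" and "discrete" to column lists and may name a "time_column".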


def parse_args():
    parser = argparse.ArgumentParser(description="Prepare vocab and stats for HAI.")
    parser.add_argument("--data-glob", default=str(DATA_GLOB), help="Glob for train CSVs")
    parser.add_argument("--split-path", default=str(SPLIT_PATH), help="Split JSON path")
    parser.add_argument("--out-stats", default=str(OUT_STATS), help="Output stats JSON")
    parser.add_argument("--out-vocab", default=str(OUT_VOCAB), help="Output vocab JSON")
    parser.add_argument("--max-rows", type=int, default=50000, help="Row cap for speed")
    return parser.parse_args()


def main(
    max_rows: Optional[int] = None,
    split_path: Optional[str] = None,
    data_glob: Optional[str] = None,
    out_stats: Optional[str] = None,
    out_vocab: Optional[str] = None,
):
    # Any argument left as None falls back to the module-level default.
    split_path = split_path or str(SPLIT_PATH)
    data_glob = data_glob or str(DATA_GLOB)
    out_stats = out_stats or str(OUT_STATS)
    out_vocab = out_vocab or str(OUT_VOCAB)

    # Load the feature split and drop the time column everywhere; attack
    # label columns are excluded from the discrete features.
    split = load_split(safe_path(split_path))
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    # Expand the glob relative to this script and fail fast if nothing matches.
    glob_path = resolve_path(BASE_DIR, data_glob)
    data_paths = sorted(Path(glob_path).parent.glob(Path(glob_path).name))
    if not data_paths:
        raise SystemExit(f"no train files found under {glob_path}")
    data_paths = [safe_path(p) for p in data_paths]

    mean, std, vmin, vmax, int_like, max_decimals = compute_cont_stats(
        data_paths, cont_cols, max_rows=max_rows
    )
    vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows)

    # Write both JSON outputs, creating each output directory if needed
    # (the two paths may differ when overridden on the command line).
    ensure_dir(Path(out_stats).parent)
    ensure_dir(Path(out_vocab).parent)
    with open(safe_path(out_stats), "w", encoding="utf-8") as f:
        json.dump(
            {
                "mean": mean,
                "std": std,
                "min": vmin,
                "max": vmax,
                "int_like": int_like,
                "max_decimals": max_decimals,
                "max_rows": max_rows,
            },
            f,
            indent=2,
        )
    with open(safe_path(out_vocab), "w", encoding="utf-8") as f:
        json.dump({"vocab": vocab, "top_token": top_token, "max_rows": max_rows}, f, indent=2)


if __name__ == "__main__":
    args = parse_args()
    main(
        max_rows=args.max_rows,
        split_path=args.split_path,
        data_glob=args.data_glob,
        out_stats=args.out_stats,
        out_vocab=args.out_vocab,
    )
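
# Example invocation (a sketch, assuming the HAI 21.03 train CSVs have been
# placed under dataset/hai/hai-21.03/ as train*.csv.gz):
#   python prepare_data.py --max-rows 50000
# This writes results/cont_stats.json and results/disc_vocab.json next to
# this script.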