#!/usr/bin/env python3
"""Prepare vocab and normalization stats for HAI-style CSV datasets."""

import argparse
import json
from pathlib import Path
from typing import Optional

from data_utils import compute_cont_stats, build_disc_stats, load_split, choose_cont_transforms
from platform_utils import safe_path, ensure_dir, resolve_path

BASE_DIR = Path(__file__).resolve().parent
REPO_DIR = BASE_DIR.parent.parent


def parse_args():
    """Parse the command-line options of this script.

    Returns:
        The argparse namespace with ``config`` (defaults to ``config.json``
        next to this script) and ``max_rows`` (defaults to 50000).
    """
    cli = argparse.ArgumentParser(description="Prepare vocab and normalization stats.")
    cli.add_argument(
        "--config",
        default=str(BASE_DIR / "config.json"),
        help="Path to JSON config",
    )
    cli.add_argument(
        "--max-rows",
        type=int,
        default=50000,
        help="Sample cap for stats; ignored when full_stats=true",
    )
    return cli.parse_args()


def resolve_data_paths(cfg: dict, cfg_path: Path) -> list[str]:
    """Collect the input data files named by the config.

    ``data_glob`` takes precedence over ``data_path``; both are resolved
    against the directory that contains the config file.

    Args:
        cfg: Parsed config mapping.
        cfg_path: Location of the config file.

    Returns:
        Each selected path passed through ``safe_path``. The list is empty
        when neither key is set, when the glob matches nothing, or when the
        single ``data_path`` does not exist.
    """
    config_dir = cfg_path.parent
    pattern = cfg.get("data_glob", "")
    single_file = cfg.get("data_path", "")

    if pattern:
        # Only the final path component is used as the glob pattern; it is
        # matched inside the pattern's parent directory.
        glob_target = Path(resolve_path(config_dir, pattern))
        matches = sorted(glob_target.parent.glob(glob_target.name))
        return [safe_path(match) for match in matches]

    if not single_file:
        return []

    candidate = Path(resolve_path(config_dir, single_file))
    if not candidate.exists():
        return []
    return [safe_path(candidate)]


def main():
    """Compute continuous stats and the discrete vocab, then write both as JSON.

    Raises:
        SystemExit: If the config file is missing or no data files are found.
    """
    args = parse_args()

    # Relative config paths are resolved against this script's directory.
    config_path = Path(args.config)
    if not config_path.is_absolute():
        config_path = resolve_path(BASE_DIR, config_path)
    if not config_path.exists():
        raise SystemExit(f"missing config: {config_path}")

    cfg = json.loads(config_path.read_text(encoding="utf-8"))
    config_dir = config_path.parent

    # quantile_bins is only read from the config when the quantile
    # transform is switched on; otherwise it stays None.
    if bool(cfg.get("use_quantile_transform", False)):
        quantile_bins = int(cfg.get("quantile_bins", 0))
    else:
        quantile_bins = None

    # full_stats disables the row cap entirely.
    max_rows: Optional[int] = None if bool(cfg.get("full_stats", False)) else args.max_rows

    split_file = resolve_path(config_dir, cfg.get("split_path", "./feature_split.json"))
    split = load_split(safe_path(split_file))
    time_col = split.get("time_column", "time")

    # The time column is dropped from both groups; discrete columns whose
    # name starts with "attack" are dropped as well.
    cont_cols = []
    for col in split["continuous"]:
        if col != time_col:
            cont_cols.append(col)
    disc_cols = []
    for col in split["discrete"]:
        if not (col.startswith("attack") or col == time_col):
            disc_cols.append(col)

    data_paths = resolve_data_paths(cfg, config_path)
    if not data_paths:
        raise SystemExit(f"no train files found for config: {config_path}")

    transforms, _ = choose_cont_transforms(data_paths, cont_cols, max_rows=max_rows)
    cont_stats = compute_cont_stats(
        data_paths,
        cont_cols,
        max_rows=max_rows,
        transforms=transforms,
        quantile_bins=quantile_bins,
    )
    vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows)

    stats_file = resolve_path(config_dir, cfg.get("stats_path", "./results/cont_stats.json"))
    vocab_file = resolve_path(config_dir, cfg.get("vocab_path", "./results/disc_vocab.json"))
    ensure_dir(stats_file.parent)
    ensure_dir(vocab_file.parent)

    # Fields copied from cont_stats into the stats file, in output order.
    stats_fields = (
        "mean",
        "std",
        "raw_mean",
        "raw_std",
        "min",
        "max",
        "int_like",
        "max_decimals",
        "transform",
        "skew",
        "max_rows",
        "quantile_probs",
        "quantile_values",
        "quantile_raw_values",
    )
    with open(safe_path(stats_file), "w", encoding="utf-8") as handle:
        stats_payload = {field: cont_stats[field] for field in stats_fields}
        json.dump(stats_payload, handle, indent=2)

    with open(safe_path(vocab_file), "w", encoding="utf-8") as handle:
        vocab_payload = {"vocab": vocab, "top_token": top_token, "max_rows": max_rows}
        json.dump(vocab_payload, handle, indent=2)


if __name__ == "__main__":
    main()