77 lines
2.9 KiB
Python
Executable File
77 lines
2.9 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""Prepare vocab and normalization stats for HAI 21.03."""
|
|
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from data_utils import compute_cont_stats, build_disc_stats, load_split, choose_cont_transforms
|
|
from platform_utils import safe_path, ensure_dir
|
|
|
|
# Directory containing this script; all other paths are anchored here.
BASE_DIR = Path(__file__).resolve().parent
# Repository root, two levels above the script directory.
REPO_DIR = BASE_DIR.parent.parent

# Glob pattern (directory + file pattern) for the HAI 21.03 training files.
DATA_GLOB = REPO_DIR.joinpath("dataset", "hai", "hai-21.03", "train*.csv.gz")

# Feature split definition read by main() (its "continuous"/"discrete" lists).
SPLIT_PATH = BASE_DIR.joinpath("feature_split.json")

# JSON outputs written by main().
OUT_STATS = BASE_DIR.joinpath("results", "cont_stats.json")
OUT_VOCAB = BASE_DIR.joinpath("results", "disc_vocab.json")
|
|
|
|
|
|
# Keys copied from compute_cont_stats' result into cont_stats.json, in write order.
_CONT_STAT_KEYS = (
    "mean",
    "std",
    "raw_mean",
    "raw_std",
    "min",
    "max",
    "int_like",
    "max_decimals",
    "transform",
    "skew",
    "max_rows",
    "quantile_probs",
    "quantile_values",
)


def _read_quantile_bins(config_path: Path) -> Optional[int]:
    """Return the configured quantile bin count, or None when it is not enabled.

    Reads ``use_quantile_transform`` and ``quantile_bins`` from the JSON file at
    *config_path*. A missing file, or ``use_quantile_transform`` falsy, yields None.
    """
    if not config_path.exists():
        return None
    cfg = json.loads(config_path.read_text(encoding="utf-8"))
    if not bool(cfg.get("use_quantile_transform", False)):
        return None
    # A missing or null "quantile_bins" both become 0 (null used to raise TypeError).
    # NOTE(review): how compute_cont_stats treats 0 bins is not visible here — confirm.
    return int(cfg.get("quantile_bins") or 0)


def _select_columns(split):
    """Return (cont_cols, disc_cols) from the loaded feature split.

    The time column (``split["time_column"]``, default ``"time"``) is dropped from
    both lists, and discrete columns whose name starts with ``"attack"`` are dropped.
    """
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [
        c for c in split["discrete"] if not c.startswith("attack") and c != time_col
    ]
    return cont_cols, disc_cols


def _find_train_files():
    """Return the sorted training files matching DATA_GLOB, each passed through safe_path.

    Raises:
        SystemExit: if no file matches.
    """
    # Search with DATA_GLOB itself so the search and the error message cannot diverge.
    data_paths = sorted(DATA_GLOB.parent.glob(DATA_GLOB.name))
    if not data_paths:
        raise SystemExit(f"no train files found under {DATA_GLOB}")
    return [safe_path(p) for p in data_paths]


def main(max_rows: Optional[int] = None):
    """Compute continuous stats and discrete vocab for HAI 21.03 and write them as JSON.

    Continuous stats go to OUT_STATS and the discrete vocab/top tokens go to OUT_VOCAB.

    Args:
        max_rows: Row cap forwarded to every data_utils helper. ``None`` is
            forwarded unchanged; the script entry point describes it as a full scan.

    Raises:
        SystemExit: if no training files match DATA_GLOB.
    """
    quantile_bins = _read_quantile_bins(BASE_DIR / "config.json")
    cont_cols, disc_cols = _select_columns(load_split(safe_path(SPLIT_PATH)))
    data_paths = _find_train_files()

    transforms, _ = choose_cont_transforms(data_paths, cont_cols, max_rows=max_rows)
    cont_stats = compute_cont_stats(
        data_paths,
        cont_cols,
        max_rows=max_rows,
        transforms=transforms,
        quantile_bins=quantile_bins,
    )
    vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows)

    # Create every distinct output directory (deduplicated, order-preserving)
    # instead of assuming both outputs share OUT_STATS' directory.
    for out_dir in dict.fromkeys((OUT_STATS.parent, OUT_VOCAB.parent)):
        ensure_dir(out_dir)

    with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f:
        json.dump({key: cont_stats[key] for key in _CONT_STAT_KEYS}, f, indent=2)

    with open(safe_path(OUT_VOCAB), "w", encoding="utf-8") as f:
        json.dump({"vocab": vocab, "top_token": top_token, "max_rows": max_rows}, f, indent=2)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# Default: sample 50000 rows for speed. Set to None for full scan.
|
|
main(max_rows=50000)
|