Add comprehensive evaluation and ablation runner
This commit is contained in:
@@ -1,44 +1,66 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Prepare vocab and normalization stats for HAI 21.03."""
|
||||
"""Prepare vocab and normalization stats for HAI-style CSV datasets."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from data_utils import compute_cont_stats, build_disc_stats, load_split, choose_cont_transforms
|
||||
from platform_utils import safe_path, ensure_dir
|
||||
from platform_utils import safe_path, ensure_dir, resolve_path
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
REPO_DIR = BASE_DIR.parent.parent
|
||||
DATA_GLOB = REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"
|
||||
SPLIT_PATH = BASE_DIR / "feature_split.json"
|
||||
OUT_STATS = BASE_DIR / "results" / "cont_stats.json"
|
||||
OUT_VOCAB = BASE_DIR / "results" / "disc_vocab.json"
|
||||
|
||||
|
||||
def main(max_rows: Optional[int] = None):
|
||||
config_path = BASE_DIR / "config.json"
|
||||
use_quantile = False
|
||||
quantile_bins = None
|
||||
full_stats = False
|
||||
if config_path.exists():
|
||||
cfg = json.loads(config_path.read_text(encoding="utf-8"))
|
||||
use_quantile = bool(cfg.get("use_quantile_transform", False))
|
||||
quantile_bins = int(cfg.get("quantile_bins", 0)) if use_quantile else None
|
||||
full_stats = bool(cfg.get("full_stats", False))
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Prepare vocab and normalization stats.")
|
||||
parser.add_argument("--config", default=str(BASE_DIR / "config.json"), help="Path to JSON config")
|
||||
parser.add_argument("--max-rows", type=int, default=50000, help="Sample cap for stats; ignored when full_stats=true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def resolve_data_paths(cfg: dict, cfg_path: Path) -> list[str]:
|
||||
base_dir = cfg_path.parent
|
||||
data_glob = cfg.get("data_glob", "")
|
||||
data_path = cfg.get("data_path", "")
|
||||
paths = []
|
||||
if data_glob:
|
||||
resolved_glob = resolve_path(base_dir, data_glob)
|
||||
paths = sorted(Path(resolved_glob).parent.glob(Path(resolved_glob).name))
|
||||
elif data_path:
|
||||
resolved_path = resolve_path(base_dir, data_path)
|
||||
if Path(resolved_path).exists():
|
||||
paths = [Path(resolved_path)]
|
||||
return [safe_path(p) for p in paths]
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
config_path = Path(args.config)
|
||||
if not config_path.is_absolute():
|
||||
config_path = resolve_path(BASE_DIR, config_path)
|
||||
if not config_path.exists():
|
||||
raise SystemExit(f"missing config: {config_path}")
|
||||
|
||||
cfg = json.loads(config_path.read_text(encoding="utf-8"))
|
||||
use_quantile = bool(cfg.get("use_quantile_transform", False))
|
||||
quantile_bins = int(cfg.get("quantile_bins", 0)) if use_quantile else None
|
||||
full_stats = bool(cfg.get("full_stats", False))
|
||||
max_rows: Optional[int] = args.max_rows
|
||||
|
||||
if full_stats:
|
||||
max_rows = None
|
||||
|
||||
split = load_split(safe_path(SPLIT_PATH))
|
||||
split_path = resolve_path(config_path.parent, cfg.get("split_path", "./feature_split.json"))
|
||||
split = load_split(safe_path(split_path))
|
||||
time_col = split.get("time_column", "time")
|
||||
cont_cols = [c for c in split["continuous"] if c != time_col]
|
||||
disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
|
||||
|
||||
data_paths = sorted(Path(REPO_DIR / "dataset" / "hai" / "hai-21.03").glob("train*.csv.gz"))
|
||||
data_paths = resolve_data_paths(cfg, config_path)
|
||||
if not data_paths:
|
||||
raise SystemExit("no train files found under %s" % str(DATA_GLOB))
|
||||
data_paths = [safe_path(p) for p in data_paths]
|
||||
raise SystemExit(f"no train files found for config: {config_path}")
|
||||
|
||||
transforms, _ = choose_cont_transforms(data_paths, cont_cols, max_rows=max_rows)
|
||||
cont_stats = compute_cont_stats(
|
||||
@@ -50,8 +72,12 @@ def main(max_rows: Optional[int] = None):
|
||||
)
|
||||
vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows)
|
||||
|
||||
ensure_dir(OUT_STATS.parent)
|
||||
with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f:
|
||||
out_stats = resolve_path(config_path.parent, cfg.get("stats_path", "./results/cont_stats.json"))
|
||||
out_vocab = resolve_path(config_path.parent, cfg.get("vocab_path", "./results/disc_vocab.json"))
|
||||
ensure_dir(out_stats.parent)
|
||||
ensure_dir(out_vocab.parent)
|
||||
|
||||
with open(safe_path(out_stats), "w", encoding="utf-8") as f:
|
||||
json.dump(
|
||||
{
|
||||
"mean": cont_stats["mean"],
|
||||
@@ -73,10 +99,9 @@ def main(max_rows: Optional[int] = None):
|
||||
indent=2,
|
||||
)
|
||||
|
||||
with open(safe_path(OUT_VOCAB), "w", encoding="utf-8") as f:
|
||||
with open(safe_path(out_vocab), "w", encoding="utf-8") as f:
|
||||
json.dump({"vocab": vocab, "top_token": top_token, "max_rows": max_rows}, f, indent=2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Default: sample 50000 rows for speed. Set to None for full scan.
|
||||
main(max_rows=50000)
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user