diff --git a/example/README.md b/example/README.md
index 32ffe81..4d7d7f3 100644
--- a/example/README.md
+++ b/example/README.md
@@ -54,6 +54,25 @@ One-click pipeline (prepare -> train -> export -> eval -> plot):
 python example/run_pipeline.py --device auto
 ```
 
+## Ablation: Feature Split Variants
+Generate alternative continuous/discrete splits (baseline/strict/loose):
+```
+python example/ablation_splits.py --data-glob "../../dataset/hai/hai-21.03/train*.csv.gz"
+```
+
+Then run prepare/train with a chosen split:
+```
+python example/prepare_data.py --split-path example/results/ablation_splits/split_strict.json
+python example/train.py --config example/config.json --device cuda
+```
+
+Before running `train.py`, update `example/config.json` so that `split_path` points at the chosen split file.
+
+One-click ablation (runs baseline/strict/loose end-to-end):
+```
+python example/run_ablation.py --device cuda
+```
+
 
 ## Notes
 - Heuristic: integer-like values with low cardinality (<=10) are treated as discrete. All other numeric columns are continuous.
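The manual path above leaves the `example/config.json` rewiring to the user. Below is a minimal sketch of that step, mirroring the per-variant overrides that `run_ablation.py` applies further down in this patch; the `stats_path`/`vocab_path`/`out_dir` keys and the exact output filenames are assumptions taken from that script, so adjust them to your actual config and `prepare_data.py` invocation.

```python
# Illustrative sketch: point the training config at the "strict" split variant.
# Assumes the config keys that run_ablation.py overrides (split_path, stats_path,
# vocab_path, out_dir) and stats/vocab files written via prepare_data.py
# --out-stats / --out-vocab.
import json
from pathlib import Path

results = Path("example/results")
cfg_path = Path("example/config.json")

cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
cfg["split_path"] = str(results / "ablation_splits" / "split_strict.json")
cfg["stats_path"] = str(results / "cont_stats_split_strict.json")
cfg["vocab_path"] = str(results / "disc_vocab_split_strict.json")
cfg["out_dir"] = str(results / "ablation_split_strict")

cfg_path.write_text(json.dumps(cfg, indent=2), encoding="utf-8")
```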
diff --git a/example/ablation_splits.py b/example/ablation_splits.py
new file mode 100644
index 0000000..1a6a9db
--- /dev/null
+++ b/example/ablation_splits.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python3
+"""Generate multiple continuous/discrete splits for ablation."""
+
+import argparse
+import json
+from pathlib import Path
+
+from data_utils import iter_rows
+from platform_utils import resolve_path, safe_path, ensure_dir
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Generate split variants for ablation.")
+    base_dir = Path(__file__).resolve().parent
+    repo_dir = base_dir.parent.parent
+    parser.add_argument("--data-glob", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"))
+    parser.add_argument("--max-rows", type=int, default=50000)
+    parser.add_argument("--out-dir", default=str(base_dir / "results" / "ablation_splits"))
+    parser.add_argument("--time-col", default="time")
+    return parser.parse_args()
+
+
+def analyze_columns(paths, max_rows, time_col):
+    stats = {}
+    rows = 0
+    for path in paths:
+        for row in iter_rows(path):
+            rows += 1
+            for c, v in row.items():
+                if c == time_col:
+                    continue
+                st = stats.setdefault(c, {"numeric": True, "int_like": True, "unique": set(), "count": 0})
+                if v is None or v == "":
+                    continue
+                st["count"] += 1
+                if st["numeric"]:
+                    try:
+                        fv = float(v)
+                    except Exception:
+                        st["numeric"] = False
+                        st["int_like"] = False
+                        st["unique"].add(v)
+                        continue
+                    if st["int_like"] and abs(fv - round(fv)) > 1e-9:
+                        st["int_like"] = False
+                    if len(st["unique"]) < 200:
+                        st["unique"].add(fv)
+                else:
+                    if len(st["unique"]) < 200:
+                        st["unique"].add(v)
+            if max_rows is not None and rows >= max_rows:
+                return stats
+    return stats
+
+
+def build_split(stats, time_col, int_ratio=0.98, max_unique=20):
+    cont = []
+    disc = []
+    for c, st in stats.items():
+        if c == time_col:
+            continue
+        if st["count"] == 0:
+            continue
+        if not st["numeric"]:
+            disc.append(c)
+            continue
+        unique_count = len(st["unique"])
+        # if values look integer-like and low unique => discrete
+        if st["int_like"] and unique_count <= max_unique:
+            disc.append(c)
+        else:
+            cont.append(c)
+    return cont, disc
+
+
+def main():
+    args = parse_args()
+    base_dir = Path(__file__).resolve().parent
+    glob_path = resolve_path(base_dir, args.data_glob)
+    paths = sorted(Path(glob_path).parent.glob(Path(glob_path).name))
+    if not paths:
+        raise SystemExit("no train files found under %s" % str(glob_path))
+    paths = [safe_path(p) for p in paths]
+
+    stats = analyze_columns(paths, args.max_rows, args.time_col)
+    ensure_dir(args.out_dir)
+
+    # baseline (current heuristic)
+    cont, disc = build_split(stats, args.time_col, max_unique=10)
+    baseline = {"time_column": args.time_col, "continuous": sorted(cont), "discrete": sorted(disc)}
+
+    # stricter discrete
+    cont_s, disc_s = build_split(stats, args.time_col, max_unique=5)
+    strict = {"time_column": args.time_col, "continuous": sorted(cont_s), "discrete": sorted(disc_s)}
+
+    # looser discrete
+    cont_l, disc_l = build_split(stats, args.time_col, max_unique=30)
+    loose = {"time_column": args.time_col, "continuous": sorted(cont_l), "discrete": sorted(disc_l)}
+
+    out_dir = Path(args.out_dir)
+    with open(out_dir / "split_baseline.json", "w", encoding="utf-8") as f:
+        json.dump(baseline, f, indent=2)
+    with open(out_dir / "split_strict.json", "w", encoding="utf-8") as f:
+        json.dump(strict, f, indent=2)
+    with open(out_dir / "split_loose.json", "w", encoding="utf-8") as f:
+        json.dump(loose, f, indent=2)
+
+    print("wrote", out_dir / "split_baseline.json")
+    print("wrote", out_dir / "split_strict.json")
+    print("wrote", out_dir / "split_loose.json")
+
+
+if __name__ == "__main__":
+    main()
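Before committing to a variant, it can help to see which columns actually change category between the baseline, strict, and loose splits generated above. A minimal sketch, assuming the default `example/results/ablation_splits/` output directory and the JSON layout written by `ablation_splits.py`:

```python
# Sketch: report how the strict/loose variants differ from the baseline split.
# Assumes the default output directory of ablation_splits.py.
import json
from pathlib import Path

splits_dir = Path("example/results/ablation_splits")
splits = {
    name: json.loads((splits_dir / f"split_{name}.json").read_text(encoding="utf-8"))
    for name in ("baseline", "strict", "loose")
}

base_disc = set(splits["baseline"]["discrete"])
for name in ("strict", "loose"):
    disc = set(splits[name]["discrete"])
    moved = sorted(disc ^ base_disc)  # columns whose category differs from baseline
    print(f"{name}: {len(disc)} discrete, {len(splits[name]['continuous'])} continuous, "
          f"{len(moved)} moved vs baseline: {moved}")
```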
diff --git a/example/prepare_data.py b/example/prepare_data.py
index 26eac22..e90d5b1 100755
--- a/example/prepare_data.py
+++ b/example/prepare_data.py
@@ -1,12 +1,13 @@
 #!/usr/bin/env python3
 """Prepare vocab and normalization stats for HAI 21.03."""
 
+import argparse
 import json
 from pathlib import Path
 from typing import Optional
 
 from data_utils import compute_cont_stats, build_disc_stats, load_split
-from platform_utils import safe_path, ensure_dir
+from platform_utils import safe_path, ensure_dir, resolve_path
 
 BASE_DIR = Path(__file__).resolve().parent
 REPO_DIR = BASE_DIR.parent.parent
@@ -16,22 +17,39 @@ OUT_STATS = BASE_DIR / "results" / "cont_stats.json"
 OUT_VOCAB = BASE_DIR / "results" / "disc_vocab.json"
 
 
-def main(max_rows: Optional[int] = None):
-    split = load_split(safe_path(SPLIT_PATH))
+def parse_args():
+    parser = argparse.ArgumentParser(description="Prepare vocab and stats for HAI.")
+    parser.add_argument("--data-glob", default=str(DATA_GLOB), help="Glob for train CSVs")
+    parser.add_argument("--split-path", default=str(SPLIT_PATH), help="Split JSON path")
+    parser.add_argument("--out-stats", default=str(OUT_STATS), help="Output stats JSON")
+    parser.add_argument("--out-vocab", default=str(OUT_VOCAB), help="Output vocab JSON")
+    parser.add_argument("--max-rows", type=int, default=50000, help="Row cap for speed")
+    return parser.parse_args()
+
+
+def main(max_rows: Optional[int] = None, split_path: Optional[str] = None, data_glob: Optional[str] = None,
+         out_stats: Optional[str] = None, out_vocab: Optional[str] = None):
+    split_path = split_path or str(SPLIT_PATH)
+    data_glob = data_glob or str(DATA_GLOB)
+    out_stats = out_stats or str(OUT_STATS)
+    out_vocab = out_vocab or str(OUT_VOCAB)
+
+    split = load_split(safe_path(split_path))
     time_col = split.get("time_column", "time")
     cont_cols = [c for c in split["continuous"] if c != time_col]
     disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
 
-    data_paths = sorted(Path(REPO_DIR / "dataset" / "hai" / "hai-21.03").glob("train*.csv.gz"))
+    glob_path = resolve_path(BASE_DIR, data_glob)
+    data_paths = sorted(Path(glob_path).parent.glob(Path(glob_path).name))
     if not data_paths:
-        raise SystemExit("no train files found under %s" % str(DATA_GLOB))
+        raise SystemExit("no train files found under %s" % str(glob_path))
     data_paths = [safe_path(p) for p in data_paths]
 
     mean, std, vmin, vmax, int_like, max_decimals = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows)
     vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows)
 
-    ensure_dir(OUT_STATS.parent)
-    with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f:
+    ensure_dir(Path(out_stats).parent)
+    with open(safe_path(out_stats), "w", encoding="utf-8") as f:
         json.dump(
             {
                 "mean": mean,
@@ -46,10 +64,16 @@ def main(max_rows: Optional[int] = None):
             indent=2,
         )
 
-    with open(safe_path(OUT_VOCAB), "w", encoding="utf-8") as f:
+    with open(safe_path(out_vocab), "w", encoding="utf-8") as f:
         json.dump({"vocab": vocab, "top_token": top_token, "max_rows": max_rows}, f, indent=2)
 
 
 if __name__ == "__main__":
-    # Default: sample 50000 rows for speed. Set to None for full scan.
-    main(max_rows=50000)
+    args = parse_args()
+    main(
+        max_rows=args.max_rows,
+        split_path=args.split_path,
+        data_glob=args.data_glob,
+        out_stats=args.out_stats,
+        out_vocab=args.out_vocab,
+    )
diff --git a/example/run_ablation.py b/example/run_ablation.py
new file mode 100644
index 0000000..e2a1792
--- /dev/null
+++ b/example/run_ablation.py
@@ -0,0 +1,96 @@
+#!/usr/bin/env python3
+"""One-click ablation runner for split variants."""
+
+import argparse
+import json
+import subprocess
+import sys
+from pathlib import Path
+
+from platform_utils import safe_path, is_windows
+
+
+def run(cmd):
+    cmd = [safe_path(c) for c in cmd]
+    if is_windows():
+        subprocess.run(cmd, check=True, shell=False)
+    else:
+        subprocess.run(cmd, check=True)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Run ablations over split variants.")
+    base_dir = Path(__file__).resolve().parent
+    parser.add_argument("--device", default="auto")
+    parser.add_argument("--config", default=str(base_dir / "config.json"))
+    parser.add_argument("--data-glob", default=str(base_dir.parent.parent / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"))
+    parser.add_argument("--max-rows", type=int, default=50000)
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+    base_dir = Path(__file__).resolve().parent
+    results_dir = base_dir / "results"
+    splits_dir = results_dir / "ablation_splits"
+    splits_dir.mkdir(parents=True, exist_ok=True)
+
+    # generate splits
+    run([sys.executable, str(base_dir / "ablation_splits.py"), "--data-glob", args.data_glob, "--max-rows", str(args.max_rows)])
+
+    split_files = [
+        splits_dir / "split_baseline.json",
+        splits_dir / "split_strict.json",
+        splits_dir / "split_loose.json",
+    ]
+
+    for split_path in split_files:
+        tag = split_path.stem
+        run([
+            sys.executable,
+            str(base_dir / "prepare_data.py"),
+            "--data-glob",
+            args.data_glob,
+            "--split-path",
+            str(split_path),
+            "--out-stats",
+            str(results_dir / f"cont_stats_{tag}.json"),
+            "--out-vocab",
+            str(results_dir / f"disc_vocab_{tag}.json"),
+        ])
+
+        # load base config, override split/stats/vocab/out_dir
+        cfg_path = Path(args.config)
+        cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
+        cfg["split_path"] = str(split_path)
+        cfg["stats_path"] = str(results_dir / f"cont_stats_{tag}.json")
+        cfg["vocab_path"] = str(results_dir / f"disc_vocab_{tag}.json")
+        cfg["out_dir"] = str(results_dir / f"ablation_{tag}")
+
+        temp_cfg = results_dir / f"config_{tag}.json"
+        temp_cfg.write_text(json.dumps(cfg, indent=2), encoding="utf-8")
+
+        run([sys.executable, str(base_dir / "train.py"), "--config", str(temp_cfg), "--device", args.device])
+        run([
+            sys.executable,
+            str(base_dir / "export_samples.py"),
+            "--include-time",
+            "--device",
+            args.device,
+            "--config",
+            str(temp_cfg),
+            "--timesteps",
+            str(cfg.get("timesteps", 400)),
+            "--seq-len",
+            str(cfg.get("sample_seq_len", cfg.get("seq_len", 128))),
+            "--batch-size",
+            str(cfg.get("sample_batch_size", 8)),
+            "--clip-k",
+            str(cfg.get("clip_k", 3.0)),
+            "--use-ema",
+        ])
+        run([sys.executable, str(base_dir / "evaluate_generated.py"), "--out", str(results_dir / f"ablation_{tag}" / "eval.json")])
+
+
+if __name__ == "__main__":
+    main()