#!/usr/bin/env python3
"""Generate multiple continuous/discrete splits for ablation."""

import argparse
import json
from pathlib import Path

from data_utils import iter_rows
from platform_utils import resolve_path, safe_path, ensure_dir


def parse_args():
    parser = argparse.ArgumentParser(description="Generate split variants for ablation.")
    base_dir = Path(__file__).resolve().parent
    repo_dir = base_dir.parent.parent
    parser.add_argument("--data-glob",
                        default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"))
    parser.add_argument("--max-rows", type=int, default=50000)
    parser.add_argument("--out-dir", default=str(base_dir / "results" / "ablation_splits"))
    parser.add_argument("--time-col", default="time")
    return parser.parse_args()


def analyze_columns(paths, max_rows, time_col):
    """Scan up to max_rows rows and collect per-column type statistics."""
    stats = {}
    rows = 0
    for path in paths:
        for row in iter_rows(path):
            rows += 1
            for c, v in row.items():
                if c == time_col:
                    continue
                st = stats.setdefault(
                    c, {"numeric": True, "int_like": True, "unique": set(), "count": 0}
                )
                if v is None or v == "":
                    continue  # skip missing values entirely
                st["count"] += 1
                if st["numeric"]:
                    try:
                        fv = float(v)
                    except (TypeError, ValueError):
                        # First non-numeric value demotes the column to categorical.
                        st["numeric"] = False
                        st["int_like"] = False
                        st["unique"].add(v)
                        continue
                    if st["int_like"] and abs(fv - round(fv)) > 1e-9:
                        st["int_like"] = False
                    # Cap the unique-value set: past 200 distinct values the
                    # column is continuous under every threshold used below.
                    if len(st["unique"]) < 200:
                        st["unique"].add(fv)
                elif len(st["unique"]) < 200:
                    st["unique"].add(v)
            if max_rows is not None and rows >= max_rows:
                return stats
    return stats


def build_split(stats, time_col, max_unique=20):
    """Partition columns into continuous and discrete lists."""
    cont = []
    disc = []
    for c, st in stats.items():
        if c == time_col:
            continue
        if st["count"] == 0:
            continue  # column had no usable values; exclude it from both lists
        if not st["numeric"]:
            disc.append(c)
            continue
        # Integer-valued with low cardinality => discrete; otherwise continuous.
        if st["int_like"] and len(st["unique"]) <= max_unique:
            disc.append(c)
        else:
            cont.append(c)
    return cont, disc


def main():
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    glob_path = resolve_path(base_dir, args.data_glob)
    paths = sorted(Path(glob_path).parent.glob(Path(glob_path).name))
    if not paths:
        raise SystemExit("no train files found under %s" % str(glob_path))
    paths = [safe_path(p) for p in paths]
    stats = analyze_columns(paths, args.max_rows, args.time_col)
    ensure_dir(args.out_dir)
    out_dir = Path(args.out_dir)

    # Three variants of the same heuristic, differing only in how many distinct
    # integer values a column may take before it is treated as continuous.
    variants = {
        "split_baseline.json": 10,  # current heuristic
        "split_strict.json": 5,     # stricter: fewer columns counted as discrete
        "split_loose.json": 30,     # looser: more columns counted as discrete
    }
    for name, max_unique in variants.items():
        cont, disc = build_split(stats, args.time_col, max_unique=max_unique)
        split = {
            "time_column": args.time_col,
            "continuous": sorted(cont),
            "discrete": sorted(disc),
        }
        with open(out_dir / name, "w", encoding="utf-8") as f:
            json.dump(split, f, indent=2)
        print("wrote", out_dir / name)


if __name__ == "__main__":
    main()