#!/usr/bin/env python3
"""Generate multiple continuous/discrete splits for ablation."""
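
# Example invocation (script name assumed here; flag values shown are the argparse defaults below):
#   python generate_ablation_splits.py --max-rows 50000 --out-dir results/ablation_splits
# Each run writes split_baseline.json, split_strict.json and split_loose.json, all of the form
# {"time_column": "...", "continuous": [...], "discrete": [...]}.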

import argparse
import json
from pathlib import Path

from data_utils import iter_rows
from platform_utils import resolve_path, safe_path, ensure_dir


def parse_args():
    parser = argparse.ArgumentParser(description="Generate split variants for ablation.")
    base_dir = Path(__file__).resolve().parent
    repo_dir = base_dir.parent.parent
    parser.add_argument("--data-glob", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"))
    parser.add_argument("--max-rows", type=int, default=50000)
    parser.add_argument("--out-dir", default=str(base_dir / "results" / "ablation_splits"))
    parser.add_argument("--time-col", default="time")
    return parser.parse_args()


def analyze_columns(paths, max_rows, time_col):
    stats = {}
    rows = 0
    for path in paths:
        for row in iter_rows(path):
            rows += 1
            for c, v in row.items():
                if c == time_col:
                    continue
                st = stats.setdefault(c, {"numeric": True, "int_like": True, "unique": set(), "count": 0})
                if v is None or v == "":
                    continue
                st["count"] += 1
                if st["numeric"]:
                    try:
                        fv = float(v)
                    except Exception:
                        st["numeric"] = False
                        st["int_like"] = False
                        st["unique"].add(v)
                        continue
                    if st["int_like"] and abs(fv - round(fv)) > 1e-9:
                        st["int_like"] = False
                    if len(st["unique"]) < 200:
                        st["unique"].add(fv)
                else:
                    if len(st["unique"]) < 200:
                        st["unique"].add(v)
            if max_rows is not None and rows >= max_rows:
                return stats
    return stats
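
# Each entry in the stats dict consumed by build_split below ends up shaped roughly like
# (illustrative values only, not from a real run):
#   {"numeric": True, "int_like": True, "unique": {0.0, 1.0}, "count": 50000}
# "unique" is capped at 200 values, so it only lower-bounds the cardinality of wide columns.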


def build_split(stats, time_col, int_ratio=0.98, max_unique=20):
    cont = []
    disc = []
    for c, st in stats.items():
        if c == time_col:
            continue
        if st["count"] == 0:
            continue
        if not st["numeric"]:
            disc.append(c)
            continue
        unique_count = len(st["unique"])
        # if values look integer-like and low unique => discrete
        if st["int_like"] and unique_count <= max_unique:
            disc.append(c)
        else:
            cont.append(c)
    return cont, disc
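
# Illustrative behaviour of the heuristic (hypothetical column names and stats, not taken from the dataset):
#   stats = {"valve_pos": {"numeric": True, "int_like": False, "unique": set(), "count": 100},
#            "pump_on": {"numeric": True, "int_like": True, "unique": {0.0, 1.0}, "count": 100}}
#   build_split(stats, "time", max_unique=10) -> (["valve_pos"], ["pump_on"])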


def main():
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    glob_path = resolve_path(base_dir, args.data_glob)
    paths = sorted(Path(glob_path).parent.glob(Path(glob_path).name))
    if not paths:
        raise SystemExit("no train files found under %s" % str(glob_path))
    paths = [safe_path(p) for p in paths]

    stats = analyze_columns(paths, args.max_rows, args.time_col)
    ensure_dir(args.out_dir)

    # baseline (current heuristic)
    cont, disc = build_split(stats, args.time_col, max_unique=10)
    baseline = {"time_column": args.time_col, "continuous": sorted(cont), "discrete": sorted(disc)}

    # stricter discrete
    cont_s, disc_s = build_split(stats, args.time_col, max_unique=5)
    strict = {"time_column": args.time_col, "continuous": sorted(cont_s), "discrete": sorted(disc_s)}

    # looser discrete
    cont_l, disc_l = build_split(stats, args.time_col, max_unique=30)
    loose = {"time_column": args.time_col, "continuous": sorted(cont_l), "discrete": sorted(disc_l)}

    out_dir = Path(args.out_dir)
    with open(out_dir / "split_baseline.json", "w", encoding="utf-8") as f:
        json.dump(baseline, f, indent=2)
    with open(out_dir / "split_strict.json", "w", encoding="utf-8") as f:
        json.dump(strict, f, indent=2)
    with open(out_dir / "split_loose.json", "w", encoding="utf-8") as f:
        json.dump(loose, f, indent=2)

    print("wrote", out_dir / "split_baseline.json")
    print("wrote", out_dir / "split_strict.json")
    print("wrote", out_dir / "split_loose.json")


if __name__ == "__main__":
    main()