update
This commit is contained in:
114
example/ablation_splits.py
Normal file
114
example/ablation_splits.py
Normal file
@@ -0,0 +1,114 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate multiple continuous/discrete splits for ablation."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from data_utils import iter_rows
|
||||
from platform_utils import resolve_path, safe_path, ensure_dir
|
||||
|
||||
|
||||
def parse_args(argv=None):
    """Parse command-line options for the split-ablation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with ``data_glob``, ``max_rows``, ``out_dir`` and
        ``time_col`` attributes.
    """
    # Defaults are anchored to this file's location, not the current working
    # directory, so the script can be launched from anywhere.
    base_dir = Path(__file__).resolve().parent
    repo_dir = base_dir.parent.parent

    parser = argparse.ArgumentParser(description="Generate split variants for ablation.")
    parser.add_argument(
        "--data-glob",
        default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"),
    )
    parser.add_argument("--max-rows", type=int, default=50000)
    parser.add_argument("--out-dir", default=str(base_dir / "results" / "ablation_splits"))
    parser.add_argument("--time-col", default="time")
    return parser.parse_args(argv)
|
||||
|
||||
|
||||
def _is_int_like(value, tol=1e-9):
    """Return True when *value* is finite and within *tol* of a whole number."""
    try:
        return abs(value - round(value)) <= tol
    except (ValueError, OverflowError):
        # round() raises ValueError for NaN and OverflowError for +/-inf;
        # neither is integer-like.
        return False


def _record_value(st, value, unique_cap):
    """Fold one non-empty cell into the per-column accumulator *st*."""
    st["count"] += 1
    if st["numeric"]:
        try:
            parsed = float(value)
        except (TypeError, ValueError, OverflowError):
            # First unparseable cell: the column is non-numeric from here on,
            # and the raw value is what gets tracked as a distinct value.
            st["numeric"] = False
            st["int_like"] = False
        else:
            if st["int_like"] and not _is_int_like(parsed):
                st["int_like"] = False
            value = parsed
    # Only a bounded sample of distinct values is kept to limit memory.
    if len(st["unique"]) < unique_cap:
        st["unique"].add(value)


def analyze_columns(paths, max_rows, time_col, unique_cap=200, row_reader=None):
    """Collect per-column statistics used to classify columns.

    Args:
        paths: Iterable of inputs, each handed to ``row_reader``.
        max_rows: Stop once this many rows (across all paths) have been
            processed; ``None`` means no limit.
        time_col: Column name excluded from the statistics.
        unique_cap: Maximum number of distinct values remembered per column.
        row_reader: Callable taking one path and yielding mapping rows
            (anything with ``.items()``). Defaults to ``iter_rows``.

    Returns:
        Dict mapping column name to a dict with keys ``numeric`` (every
        non-empty value so far parsed as float), ``int_like`` (every parsed
        value was within 1e-9 of an integer), ``unique`` (set of at most
        ``unique_cap`` observed values) and ``count`` (number of non-empty
        cells).
    """
    if row_reader is None:
        row_reader = iter_rows

    stats = {}
    rows_seen = 0
    for path in paths:
        for row in row_reader(path):
            rows_seen += 1
            for col, value in row.items():
                if col == time_col:
                    continue
                st = stats.get(col)
                if st is None:
                    # Register the column even if this cell is empty, so
                    # all-empty columns still appear with count == 0.
                    st = {"numeric": True, "int_like": True, "unique": set(), "count": 0}
                    stats[col] = st
                if value is None or value == "":
                    continue
                _record_value(st, value, unique_cap)
            if max_rows is not None and rows_seen >= max_rows:
                return stats
    return stats
|
||||
|
||||
|
||||
def build_split(stats, time_col, int_ratio=0.98, max_unique=20):
    """Partition columns into continuous and discrete lists.

    Args:
        stats: Mapping of column name to a dict with ``count``, ``numeric``,
            ``int_like`` and ``unique`` entries.
        time_col: Column name to leave out of both lists.
        int_ratio: Accepted for interface compatibility; this implementation
            never reads it.
        max_unique: An integer-like numeric column is discrete when it has at
            most this many recorded distinct values.

    Returns:
        Tuple ``(continuous, discrete)`` of column-name lists, in the
        iteration order of ``stats``.
    """
    continuous = []
    discrete = []
    for col, st in stats.items():
        # Skip the time column and columns that never had a non-empty value.
        if col == time_col or st["count"] == 0:
            continue
        if not st["numeric"]:
            # Non-numeric columns are always discrete.
            discrete.append(col)
            continue
        # Integer-like values with few distinct levels => discrete.
        is_discrete = st["int_like"] and len(st["unique"]) <= max_unique
        (discrete if is_discrete else continuous).append(col)
    return continuous, discrete
|
||||
|
||||
|
||||
def main():
    """Analyze the training files and write three split variants as JSON.

    Writes ``split_baseline.json``, ``split_strict.json`` and
    ``split_loose.json`` into the output directory. Each holds the time
    column name plus sorted ``continuous`` and ``discrete`` column lists,
    built with a different ``max_unique`` threshold.

    Raises:
        SystemExit: If the data glob matches no files.
    """
    args = parse_args()
    base_dir = Path(__file__).resolve().parent

    # resolve_path / safe_path / ensure_dir are project helpers; presumably
    # they normalise paths and create the output directory -- confirm in
    # platform_utils.
    glob_path = resolve_path(base_dir, args.data_glob)
    pattern = Path(glob_path)
    paths = sorted(pattern.parent.glob(pattern.name))
    if not paths:
        raise SystemExit("no train files found under %s" % str(glob_path))
    paths = [safe_path(p) for p in paths]

    stats = analyze_columns(paths, args.max_rows, args.time_col)
    ensure_dir(args.out_dir)
    out_dir = Path(args.out_dir)

    # (output file, max_unique): baseline is the current heuristic, strict
    # admits fewer columns as discrete, loose admits more.
    variants = (
        ("split_baseline.json", 10),
        ("split_strict.json", 5),
        ("split_loose.json", 30),
    )
    for file_name, max_unique in variants:
        cont, disc = build_split(stats, args.time_col, max_unique=max_unique)
        payload = {
            "time_column": args.time_col,
            "continuous": sorted(cont),
            "discrete": sorted(disc),
        }
        target = out_dir / file_name
        with open(target, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)
        print("wrote", target)
|
||||
|
||||
|
||||
# Run the generator only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user