Files
mask-ddpm/example/ablation_splits.py
2026-01-23 12:40:20 +08:00

115 lines
4.0 KiB
Python

#!/usr/bin/env python3
"""Generate multiple continuous/discrete splits for ablation."""
import argparse
import json
from pathlib import Path
from data_utils import iter_rows
from platform_utils import resolve_path, safe_path, ensure_dir
def parse_args(argv=None):
    """Parse command-line options for the ablation split generator.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is what the script entry
            point relies on. Passing a list allows the parser to be driven
            programmatically without touching global state.

    Returns:
        argparse.Namespace with the attributes ``data_glob``, ``max_rows``,
        ``out_dir`` and ``time_col``.
    """
    parser = argparse.ArgumentParser(description="Generate split variants for ablation.")
    # Defaults are anchored to this script's location, not the current
    # working directory, so the tool behaves the same from any cwd.
    base_dir = Path(__file__).resolve().parent
    repo_dir = base_dir.parent.parent
    parser.add_argument(
        "--data-glob",
        default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"),
        help="Glob pattern selecting the training files to scan.",
    )
    parser.add_argument(
        "--max-rows",
        type=int,
        default=50000,
        help="Stop analyzing after this many rows.",
    )
    parser.add_argument(
        "--out-dir",
        default=str(base_dir / "results" / "ablation_splits"),
        help="Directory that receives the generated split JSON files.",
    )
    parser.add_argument(
        "--time-col",
        default="time",
        help="Name of the time column; it is excluded from both split lists.",
    )
    return parser.parse_args(argv)
def analyze_columns(paths, max_rows, time_col):
    """Scan rows from ``paths`` and collect per-column value statistics.

    Args:
        paths: Iterable of file paths; each one is handed to ``iter_rows``,
            whose rows are consumed through ``row.items()``.
        max_rows: Stop once this many rows (across all files) have been
            processed. ``None`` disables the limit.
        time_col: Column name to ignore entirely.

    Returns:
        dict mapping column name to a stats dict with the keys:

        * ``"numeric"``: True while every non-missing value parsed as float.
        * ``"int_like"``: True while every numeric value was within 1e-9 of
          an integer.
        * ``"unique"``: set of observed values (floats for numeric columns),
          no longer grown once it holds 200 entries, except for the single
          value that first flips a column to non-numeric.
        * ``"count"``: number of non-missing values seen.

    ``None``, the empty string and NaN are treated as missing. Infinite
    values count as numeric but are never integer-like.
    """
    import math  # local import keeps this block self-contained

    stats = {}
    rows = 0
    for path in paths:
        for row in iter_rows(path):
            rows += 1
            for c, v in row.items():
                if c == time_col:
                    continue
                st = stats.get(c)
                if st is None:
                    st = stats[c] = {
                        "numeric": True,
                        "int_like": True,
                        "unique": set(),
                        "count": 0,
                    }
                if v is None or v == "":
                    continue
                if not st["numeric"]:
                    # Column already known to hold non-numeric data: keep
                    # the raw value, subject to the cap.
                    st["count"] += 1
                    if len(st["unique"]) < 200:
                        st["unique"].add(v)
                    continue
                try:
                    fv = float(v)
                except (TypeError, ValueError, OverflowError):
                    # First unparseable value: the column is categorical.
                    st["count"] += 1
                    st["numeric"] = False
                    st["int_like"] = False
                    st["unique"].add(v)
                    continue
                if math.isnan(fv):
                    # float() accepts "nan", but round(nan) raises and every
                    # parsed NaN is a distinct, never-equal object that would
                    # inflate the unique set. Treat it as a missing value.
                    continue
                st["count"] += 1
                if math.isinf(fv):
                    # round(inf) raises OverflowError; an infinite reading is
                    # certainly not a discrete integer level.
                    st["int_like"] = False
                elif st["int_like"] and abs(fv - round(fv)) > 1e-9:
                    st["int_like"] = False
                if len(st["unique"]) < 200:
                    st["unique"].add(fv)
            if max_rows is not None and rows >= max_rows:
                return stats
    return stats
def build_split(stats, time_col, int_ratio=0.98, max_unique=20):
    """Partition columns into continuous and discrete lists.

    Args:
        stats: Mapping of column name to a stats dict with the keys
            ``"count"``, ``"numeric"``, ``"int_like"`` and ``"unique"``.
        time_col: Column name that is left out of both lists.
        int_ratio: Accepted for interface compatibility; this function does
            not read it.
        max_unique: A numeric, integer-like column is discrete when it has at
            most this many distinct values.

    Returns:
        Tuple ``(continuous, discrete)`` of column-name lists, each in the
        iteration order of ``stats``. Columns with a zero count are dropped;
        non-numeric columns are always discrete.
    """
    continuous, discrete = [], []
    for name, info in stats.items():
        # Skip the time column and columns that never had a usable value.
        if name == time_col or info["count"] == 0:
            continue
        if not info["numeric"]:
            bucket = discrete
        else:
            n_unique = len(info["unique"])
            # Few distinct integer-valued readings => treat as categorical.
            few_integer_levels = info["int_like"] and n_unique <= max_unique
            bucket = discrete if few_integer_levels else continuous
        bucket.append(name)
    return continuous, discrete
def main():
    """Analyze the training files and write three split variants as JSON.

    The data glob is passed through ``resolve_path`` together with this
    script's directory; its parent directory is then globbed with the final
    path component as the pattern. Column statistics are gathered once and
    reused for three ``max_unique`` thresholds (10, 5 and 30), producing
    ``split_baseline.json``, ``split_strict.json`` and ``split_loose.json``
    in the output directory.

    Raises:
        SystemExit: if the glob matches no files.
    """
    args = parse_args()
    script_dir = Path(__file__).resolve().parent
    glob_path = resolve_path(script_dir, args.data_glob)
    pattern = Path(glob_path)
    matches = sorted(pattern.parent.glob(pattern.name))
    if not matches:
        raise SystemExit("no train files found under %s" % str(glob_path))
    paths = list(map(safe_path, matches))
    stats = analyze_columns(paths, args.max_rows, args.time_col)
    ensure_dir(args.out_dir)

    # (output file name, max_unique threshold): baseline is the current
    # heuristic, strict marks fewer columns discrete, loose marks more.
    variants = (
        ("split_baseline.json", 10),
        ("split_strict.json", 5),
        ("split_loose.json", 30),
    )

    # Build every payload before touching the filesystem.
    payloads = []
    for file_name, threshold in variants:
        continuous, discrete = build_split(stats, args.time_col, max_unique=threshold)
        payloads.append(
            (
                file_name,
                {
                    "time_column": args.time_col,
                    "continuous": sorted(continuous),
                    "discrete": sorted(discrete),
                },
            )
        )

    out_dir = Path(args.out_dir)
    for file_name, payload in payloads:
        with open(out_dir / file_name, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, indent=2)
    # Report only after all files were written successfully.
    for file_name, _ in payloads:
        print("wrote", out_dir / file_name)
# Run the generator only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()