#!/usr/bin/env python3
"""Analyze HAI 21.03 CSV to split features into continuous/discrete.

Heuristic, applied to the first MAX_ROWS data rows:
- integer-like values with low cardinality (<= MAX_DISCRETE_UNIQUE) -> discrete
- everything else numeric -> continuous
- non-numeric -> discrete
- no non-empty value in the sample -> unknown

The timestamp column (TIME_COLUMN) is left out of the split.
"""
import csv
import gzip
import itertools
import math
import os
from pathlib import Path
from typing import Optional, Union

BASE_DIR = Path(__file__).resolve().parent
REPO_DIR = BASE_DIR.parent.parent
DATA_PATH = REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"
OUT_DIR = BASE_DIR / "results"
MAX_ROWS = 5000

# Column excluded from the continuous/discrete split.
TIME_COLUMN = "time"
# Integer-like columns with at most this many distinct values are discrete.
MAX_DISCRETE_UNIQUE = 10
# Absolute tolerance for treating a parsed float as an integer value.
INT_TOLERANCE = 1e-9


def _open_text(path):
    """Open *path* as text suitable for ``csv.reader``; ``.gz`` is decompressed."""
    # utf-8-sig decodes independently of the platform locale and strips a
    # leading BOM that would otherwise be glued onto the first header name.
    if str(path).endswith(".gz"):
        return gzip.open(str(path), "rt", encoding="utf-8-sig", newline="")
    return open(str(path), "r", encoding="utf-8-sig", newline="")


def _is_int_like(value):
    """Return True if *value* is finite and within INT_TOLERANCE of an integer."""
    # round() raises on inf/nan, so finiteness must be checked first.
    return math.isfinite(value) and abs(value - round(value)) <= INT_TOLERANCE


def _update_stats(st, raw, unique_cap):
    """Fold one non-empty, stripped cell *raw* into the column accumulator *st*."""
    st["count"] += 1
    value = raw
    if st["numeric"]:
        try:
            value = float(raw)
        except ValueError:
            # A single non-numeric value makes the whole column categorical.
            st["numeric"] = False
            st["int_like"] = False
        else:
            if st["int_like"] and not _is_int_like(value):
                st["int_like"] = False
    # Only whether the cardinality exceeds the discrete threshold matters,
    # so stop collecting once the set is large enough to decide.
    if len(st["unique"]) < unique_cap:
        st["unique"].add(value)


def _scan(reader, cols, max_rows, unique_cap):
    """Accumulate per-column statistics over at most *max_rows* data rows.

    Returns ``(stats, rows)`` where *stats* maps column name to its
    accumulator dict and *rows* is the number of rows consumed.
    """
    stats = {
        c: {"numeric": True, "int_like": True, "unique": set(), "count": 0}
        for c in cols
    }
    # islice never reads past the limit; None means "read everything".
    limit = None if max_rows is None else max(0, max_rows)
    rows = 0
    for row in itertools.islice(reader, limit):
        rows += 1
        # zip() drops surplus cells and tolerates short rows.
        for c, raw in zip(cols, row):
            raw = raw.strip()
            if raw:
                _update_stats(stats[c], raw, unique_cap)
    return stats, rows


def _classify(cols, stats, max_discrete_unique):
    """Split *cols* (minus TIME_COLUMN) into continuous/discrete/unknown lists."""
    continuous = []
    discrete = []
    unknown = []
    for c in cols:
        if c == TIME_COLUMN:
            continue
        st = stats[c]
        if st["count"] == 0:
            unknown.append(c)
        elif not st["numeric"]:
            discrete.append(c)
        elif st["int_like"] and len(st["unique"]) <= max_discrete_unique:
            discrete.append(c)
        else:
            continuous.append(c)
    return continuous, discrete, unknown


def analyze(
    path: Union[str, Path],
    max_rows: Optional[int],
    *,
    max_discrete_unique: int = MAX_DISCRETE_UNIQUE,
):
    """Classify the columns of a CSV (optionally gzip-compressed) file.

    Args:
        path: CSV file; names ending in ``.gz`` are read through gzip.
        max_rows: maximum number of data rows to sample; ``None`` reads all,
            values <= 0 sample nothing.
        max_discrete_unique: integer-like columns with at most this many
            distinct values are classified as discrete.

    Returns:
        ``(cols, continuous, discrete, unknown, rows)`` -- the header names,
        the three column lists (in header order, TIME_COLUMN excluded) and
        the number of data rows sampled.

    Raises:
        ValueError: if the file has no header row.
        OSError: if the file cannot be opened or decompressed.
    """
    # One value past the threshold is enough to rule out "discrete".
    unique_cap = max_discrete_unique + 1
    with _open_text(path) as f:
        reader = csv.reader(f)
        header = next(reader, None)
        if header is None:
            raise ValueError("empty CSV (no header row): %s" % path)
        cols = [c.strip() for c in header]
        stats, rows = _scan(reader, cols, max_rows, unique_cap)
    continuous, discrete, unknown = _classify(cols, stats, max_discrete_unique)
    return cols, continuous, discrete, unknown, rows


def write_results(cols, continuous, discrete, unknown, rows, *, out_dir=None, data_path=None):
    """Write ``feature_split.txt`` and ``summary.txt`` into *out_dir*.

    *out_dir* defaults to OUT_DIR and *data_path* (recorded in the summary)
    defaults to DATA_PATH. The output directory is created if missing.
    """
    out_dir = Path(OUT_DIR if out_dir is None else out_dir)
    data_path = DATA_PATH if data_path is None else data_path
    os.makedirs(str(out_dir), exist_ok=True)

    split_lines = ["discrete", ",".join(discrete), "continuous", ",".join(continuous)]
    if unknown:
        split_lines += ["unknown", ",".join(unknown)]
    with open(out_dir / "feature_split.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(split_lines) + "\n")

    with open(out_dir / "summary.txt", "w", encoding="utf-8") as f:
        f.write("rows_sampled: %d\n" % rows)
        f.write("columns_total: %d\n" % len(cols))
        f.write("continuous: %d\n" % len(continuous))
        f.write("discrete: %d\n" % len(discrete))
        f.write("unknown: %d\n" % len(unknown))
        f.write("data_path: %s\n" % str(data_path))


def main():
    """Analyze DATA_PATH and write the results under OUT_DIR."""
    if not DATA_PATH.exists():
        raise SystemExit("missing data file: %s" % str(DATA_PATH))
    cols, continuous, discrete, unknown, rows = analyze(str(DATA_PATH), MAX_ROWS)
    write_results(cols, continuous, discrete, unknown, rows, data_path=DATA_PATH)


if __name__ == "__main__":
    main()