#!/usr/bin/env python3
"""Analyze the HAI 21.03 CSV to split features into continuous/discrete.

Heuristic:
- numeric columns whose values are all integer-like and that have low
  cardinality (<= 10 distinct values in the sample) -> discrete
- every other numeric column -> continuous
- non-numeric columns -> discrete
- columns with no non-empty value in the sample -> unknown
"""
import csv
import gzip
import itertools
import math
import os
from typing import Optional

DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
OUT_DIR = "/home/anay/Dev/diffusion/mask-ddpm/example/results"
MAX_ROWS = 5000

# An integer-like numeric column with at most this many distinct values is
# classified as discrete.
MAX_DISCRETE_UNIQUE = 10

# Absolute tolerance used when deciding whether a float is "integer-like".
INT_TOL = 1e-9


def _open_text(path):
    """Open *path* for CSV text reading, decompressing ``.gz`` files on the fly."""
    if path.lower().endswith(".gz"):
        return gzip.open(path, "rt", newline="")
    return open(path, "r", newline="")


def _is_int_like(value):
    """Return True if *value* is finite and within INT_TOL of an integer.

    ``round()`` raises on NaN (ValueError) and on +/-inf (OverflowError), so
    non-finite values are rejected before rounding.
    """
    return math.isfinite(value) and abs(value - round(value)) <= INT_TOL


def analyze(path: str, max_rows: Optional[int], *, skip_cols=("time",),
            max_discrete_unique: int = MAX_DISCRETE_UNIQUE):
    """Classify the columns of a CSV file as continuous, discrete or unknown.

    Args:
        path: CSV file to read; ``.gz`` files are decompressed on the fly.
        max_rows: maximum number of data rows to sample. ``None`` samples the
            whole file; values <= 0 sample no rows.
        skip_cols: column names excluded from classification. They are still
            present in the returned header list.
        max_discrete_unique: an integer-like numeric column with at most this
            many distinct values is classified as discrete.

    Returns:
        ``(cols, continuous, discrete, unknown, rows)``. *cols* is the full
        header row. The three lists hold column names in header order.
        *rows* is the number of data rows actually sampled.

    Raises:
        ValueError: if the file is empty (has no header row).
    """
    # Only "more than the threshold or not" matters, so stop collecting
    # distinct values once one more than the threshold has been seen.
    unique_cap = max_discrete_unique + 1
    limit = None if max_rows is None else max(0, max_rows)

    with _open_text(path) as f:
        reader = csv.reader(f)
        cols = next(reader, None)
        if cols is None:
            raise ValueError("empty CSV file (no header): %s" % path)

        stats = {
            c: {"numeric": True, "int_like": True, "unique": set(), "count": 0}
            for c in cols
        }
        rows = 0
        for row in itertools.islice(reader, limit):
            rows += 1
            # zip() ignores surplus cells in over-long rows and leaves missing
            # cells of short rows uncounted.
            for c, v in zip(cols, row):
                if v == "":
                    continue
                st = stats[c]
                st["count"] += 1
                if not st["numeric"]:
                    # Classification of a non-numeric column is already final.
                    continue
                try:
                    fv = float(v)
                except ValueError:
                    st["numeric"] = False
                    st["int_like"] = False
                    continue
                if st["int_like"] and not _is_int_like(fv):
                    st["int_like"] = False
                if len(st["unique"]) < unique_cap:
                    st["unique"].add(fv)

    skip = set(skip_cols)
    continuous = []
    discrete = []
    unknown = []
    for c in cols:
        if c in skip:
            continue
        st = stats[c]
        if st["count"] == 0:
            unknown.append(c)
        elif not st["numeric"]:
            discrete.append(c)
        elif st["int_like"] and len(st["unique"]) <= max_discrete_unique:
            discrete.append(c)
        else:
            continuous.append(c)
    return cols, continuous, discrete, unknown, rows


def write_results(cols, continuous, discrete, unknown, rows, *,
                  out_dir=OUT_DIR, data_path=DATA_PATH):
    """Write the feature split and a short summary into *out_dir*.

    Creates *out_dir* if needed and (over)writes two files:
    - ``feature_split.txt``: a label line ("discrete", "continuous", and
      "unknown" only when that list is non-empty), each followed by a
      comma-separated line of column names.
    - ``summary.txt``: sampled row count, column counts and *data_path*.
    """
    os.makedirs(out_dir, exist_ok=True)
    split_path = os.path.join(out_dir, "feature_split.txt")
    summary_path = os.path.join(out_dir, "summary.txt")

    split_lines = ["discrete", ",".join(discrete),
                   "continuous", ",".join(continuous)]
    if unknown:
        split_lines += ["unknown", ",".join(unknown)]

    summary_lines = [
        "rows_sampled: %d" % rows,
        "columns_total: %d" % len(cols),
        "continuous: %d" % len(continuous),
        "discrete: %d" % len(discrete),
        "unknown: %d" % len(unknown),
        "data_path: %s" % data_path,
    ]

    # UTF-8 produces the same bytes as ASCII for ASCII names but does not
    # crash if a column name or path contains non-ASCII characters.
    with open(split_path, "w", encoding="utf-8") as f:
        f.write("".join(line + "\n" for line in split_lines))
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("".join(line + "\n" for line in summary_lines))


def main():
    """Analyze DATA_PATH (first MAX_ROWS rows) and write results to OUT_DIR."""
    if not os.path.exists(DATA_PATH):
        raise SystemExit("missing data file: %s" % DATA_PATH)
    cols, continuous, discrete, unknown, rows = analyze(DATA_PATH, MAX_ROWS)
    write_results(cols, continuous, discrete, unknown, rows)


if __name__ == "__main__":
    main()