Files
mask-ddpm/example/analyze_hai21_03.py
2026-01-22 17:39:31 +08:00

108 lines
3.3 KiB
Python
Executable File

#!/usr/bin/env python3
"""Analyze HAI 21.03 CSV to split features into continuous/discrete.
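Heuristic, applied to the first MAX_ROWS data rows only:
- numeric columns whose values are all integer-like with low cardinality
  (<= 10 distinct values) -> discrete;
- all other numeric columns -> continuous;
- columns containing any non-numeric value -> discrete;
- columns with no non-empty values -> unknown.
The "time" column is excluded from classification.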
"""
import csv
import gzip
import itertools
import math
import os
from pathlib import Path
# Paths are anchored to this script's own location, not the current directory.
BASE_DIR = Path(__file__).resolve().parent
# Repository root: two directory levels above the folder holding this script.
REPO_DIR = BASE_DIR.parent.parent
# Gzip-compressed HAI 21.03 training split analyzed by main().
DATA_PATH = REPO_DIR.joinpath("dataset", "hai", "hai-21.03", "train1.csv.gz")
# Directory that receives feature_split.txt and summary.txt.
OUT_DIR = BASE_DIR.joinpath("results")
# Number of data rows sampled from DATA_PATH.
MAX_ROWS = 5_000
def _open_csv_text(path):
    """Open *path* as text suitable for ``csv.reader``, decompressing gzip.

    The gzip magic number (``1f 8b``) is sniffed rather than trusting the file
    extension, so gzip input behaves exactly as before while plain CSV files
    are also accepted.
    """
    with open(path, "rb") as raw:
        magic = raw.read(2)
    if magic == b"\x1f\x8b":
        return gzip.open(path, "rt", newline="")
    return open(path, "rt", newline="")


def _is_int_like(value, tol=1e-9):
    """Return True if *value* is finite and within *tol* of an integer.

    ``round()`` raises on NaN/inf (both of which ``float()`` happily parses),
    so non-finite values are rejected before rounding.
    """
    return math.isfinite(value) and abs(value - round(value)) <= tol


def _update_stats(st, raw, max_unique=50):
    """Fold one raw CSV cell into the per-column stats dict *st*.

    Only up to *max_unique* distinct values are remembered; that is plenty to
    decide the "<= 10 distinct values" rule while bounding memory.
    """
    if not raw:  # empty cell: a missing value, not evidence about the type
        return
    st["count"] += 1
    if st["numeric"]:
        try:
            value = float(raw)
        except ValueError:
            # The first non-numeric cell demotes the column permanently.
            st["numeric"] = False
            st["int_like"] = False
            value = raw
        else:
            if st["int_like"] and not _is_int_like(value):
                st["int_like"] = False
    else:
        value = raw
    if len(st["unique"]) < max_unique:
        st["unique"].add(value)


def _classify(cols, stats, skip_cols, max_discrete=10):
    """Bucket column names into ``(continuous, discrete, unknown)`` lists."""
    continuous, discrete, unknown = [], [], []
    for c in cols:
        if c in skip_cols:
            continue
        st = stats[c]
        if st["count"] == 0:
            unknown.append(c)
        elif not st["numeric"]:
            discrete.append(c)
        elif st["int_like"] and len(st["unique"]) <= max_discrete:
            discrete.append(c)
        else:
            continuous.append(c)
    return continuous, discrete, unknown


def analyze(path: str, max_rows: "int | None", skip_cols=("time",)):
    """Sample a (optionally gzip-compressed) CSV and split its columns by type.

    Args:
        path: CSV file to read; gzip compression is detected automatically.
        max_rows: maximum number of data rows to sample. ``None`` reads every
            row; negative values are treated as 0.
        skip_cols: column names excluded from classification (they still
            appear in the returned ``cols``).

    Returns:
        ``(cols, continuous, discrete, unknown, rows)``: the header row, three
        lists of column names in header order, and the number of data rows
        actually sampled. An empty file yields ``([], [], [], [], 0)``.

    Classification: columns with no non-empty cells are ``unknown``; any
    non-numeric cell makes a column ``discrete``; numeric columns whose values
    are all integer-like with at most 10 distinct values are ``discrete``;
    every other numeric column (including ones holding NaN/inf) is
    ``continuous``.
    """
    limit = None if max_rows is None else max(0, max_rows)
    with _open_csv_text(str(path)) as f:
        reader = csv.reader(f)
        cols = next(reader, None)
        if cols is None:  # empty file: no header, nothing to classify
            return [], [], [], [], 0
        stats = {
            c: {"numeric": True, "int_like": True, "unique": set(), "count": 0}
            for c in cols
        }
        rows = 0
        # islice enforces the limit before a row is consumed, so max_rows=0
        # really samples nothing.
        for row in itertools.islice(reader, limit):
            rows += 1
            for c, v in zip(cols, row):
                _update_stats(stats[c], v)
    continuous, discrete, unknown = _classify(cols, stats, skip_cols)
    return cols, continuous, discrete, unknown, rows
def write_results(cols, continuous, discrete, unknown, rows, *, out_dir=None, data_path=None):
    """Write the feature split and a short run summary as text files.

    Creates ``feature_split.txt`` (alternating section-title / comma-joined
    column-name lines; the ``unknown`` section is written only when non-empty)
    and ``summary.txt`` (``key: value`` counts) inside the output directory.

    Args:
        cols: all header columns (only their count is reported).
        continuous, discrete, unknown: column-name lists to record.
        rows: number of data rows that were sampled.
        out_dir: destination directory, created if missing; defaults to the
            module-level ``OUT_DIR`` (looked up at call time).
        data_path: value reported as ``data_path`` in the summary; defaults to
            the module-level ``DATA_PATH`` (looked up at call time).
    """
    out = Path(OUT_DIR if out_dir is None else out_dir)
    source = DATA_PATH if data_path is None else data_path
    out.mkdir(parents=True, exist_ok=True)

    sections = [("discrete", discrete), ("continuous", continuous)]
    if unknown:
        sections.append(("unknown", unknown))
    with open(out / "feature_split.txt", "w", encoding="utf-8") as f:
        for title, names in sections:
            f.write(title + "\n")
            f.write(",".join(names) + "\n")

    summary_lines = [
        "rows_sampled: %d" % rows,
        "columns_total: %d" % len(cols),
        "continuous: %d" % len(continuous),
        "discrete: %d" % len(discrete),
        "unknown: %d" % len(unknown),
        "data_path: %s" % str(source),
    ]
    with open(out / "summary.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(summary_lines) + "\n")
def main():
    """Analyze the configured HAI 21.03 file and write the result files.

    Exits with an error message via ``SystemExit`` when ``DATA_PATH`` is
    missing.
    """
    data_file = DATA_PATH
    if data_file.exists():
        # analyze() returns exactly the positional arguments write_results takes.
        results = analyze(str(data_file), MAX_ROWS)
        write_results(*results)
        return
    raise SystemExit("missing data file: %s" % str(data_file))


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()