Update example and notes
This commit is contained in:
104
example/analyze_hai21_03.py
Executable file
104
example/analyze_hai21_03.py
Executable file
@@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Analyze HAI 21.03 CSV to split features into continuous/discrete.
|
||||
|
||||
Heuristic: integer-like values with low cardinality (<=10) -> discrete.
|
||||
Everything else numeric -> continuous. Non-numeric -> discrete.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import gzip
|
||||
import os
|
||||
|
||||
# Gzip-compressed CSV that main() passes to analyze(). This is a
# machine-specific absolute path; adjust it for your environment.
DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
|
||||
# Directory where write_results() puts feature_split.txt and summary.txt.
# It is created on demand if it does not exist yet.
OUT_DIR = "/home/anay/Dev/diffusion/mask-ddpm/example/results"
|
||||
# Row limit that main() hands to analyze(); only this many data rows
# (the header is not counted) are sampled to classify the columns.
MAX_ROWS = 5000
|
||||
|
||||
|
||||
def analyze(path: str, max_rows: int, discrete_max_unique: int = 10):
    """Split the columns of a gzip-compressed CSV into continuous/discrete.

    The first CSV row is taken as the header. At most ``max_rows`` data rows
    are then sampled and per-column statistics are collected; empty cells
    are ignored. The column named ``"time"`` is never classified.

    Per-column rules, applied in this order:

    * no non-empty cell was seen            -> unknown
    * some cell does not parse as a float   -> discrete
    * every cell is a finite, integer-like number (within 1e-9) and there
      are at most ``discrete_max_unique`` distinct values -> discrete
    * anything else                         -> continuous

    Cells that parse to NaN or +/-inf count as numeric but not integer-like,
    so they push a column towards "continuous" instead of raising.

    Args:
        path: Path of the gzip-compressed CSV file.
        max_rows: Maximum number of data rows to sample; values <= 0 sample
            nothing, which leaves every column "unknown".
        discrete_max_unique: Largest number of distinct values an
            integer-like column may have and still be called discrete.

    Returns:
        A tuple ``(cols, continuous, discrete, unknown, rows)``: the header
        names, the three lists of classified column names (in header order)
        and the number of data rows actually sampled. A file without a
        header row yields ``([], [], [], [], 0)``.
    """
    import math  # local import so the block does not rely on file-level edits

    with gzip.open(path, "rt", newline="") as f:
        reader = csv.reader(f)
        cols = next(reader, None)
        if cols is None:
            # Completely empty file: there is nothing to classify.
            return [], [], [], [], 0

        # We only need to know whether the cardinality exceeds the
        # threshold, so tracking one value beyond it is enough.
        unique_cap = discrete_max_unique + 1
        stats = {
            c: {"numeric": True, "int_like": True, "unique": set(), "count": 0}
            for c in cols
        }

        rows = 0
        for row in reader:
            # Check the limit *before* consuming the row so that
            # max_rows <= 0 really samples nothing.
            if rows >= max_rows:
                break
            rows += 1
            for c, v in zip(cols, row):
                if v == "":
                    continue  # missing value
                st = stats[c]
                st["count"] += 1
                if not st["numeric"]:
                    continue  # already known to be non-numeric -> discrete
                try:
                    fv = float(v)
                except ValueError:
                    st["numeric"] = False
                    st["int_like"] = False
                    continue
                if not st["int_like"]:
                    continue  # cardinality only matters for int-like columns
                # round() raises on NaN/inf, so test finiteness first.
                if not math.isfinite(fv) or abs(fv - round(fv)) > 1e-9:
                    st["int_like"] = False
                elif len(st["unique"]) < unique_cap:
                    st["unique"].add(fv)

    continuous = []
    discrete = []
    unknown = []
    for c in cols:
        if c == "time":
            continue  # the timestamp column is not a feature
        st = stats[c]
        if st["count"] == 0:
            unknown.append(c)
        elif not st["numeric"]:
            discrete.append(c)
        elif st["int_like"] and len(st["unique"]) <= discrete_max_unique:
            discrete.append(c)
        else:
            continuous.append(c)

    return cols, continuous, discrete, unknown, rows
|
||||
|
||||
|
||||
def write_results(cols, continuous, discrete, unknown, rows):
    """Write the feature split and a short summary into ``OUT_DIR``.

    Two ASCII text files are produced (the directory is created if needed):

    * ``feature_split.txt`` -- a heading line followed by a comma-separated
      line of column names for "discrete", then "continuous", and for
      "unknown" only when that list is non-empty.
    * ``summary.txt`` -- ``key: value`` lines with the sampled row count,
      the total number of columns, the size of each group and the data path.

    Args:
        cols: All header names of the analyzed file.
        continuous: Names of the columns classified as continuous.
        discrete: Names of the columns classified as discrete.
        unknown: Names of the columns that had no non-empty value.
        rows: Number of data rows that were sampled.
    """
    os.makedirs(OUT_DIR, exist_ok=True)

    # Heading/names pairs in the order they appear in the split file.
    sections = [("discrete\n", discrete), ("continuous\n", continuous)]
    if unknown:
        sections.append(("unknown\n", unknown))

    split_file = os.path.join(OUT_DIR, "feature_split.txt")
    with open(split_file, "w", encoding="ascii") as out:
        for heading, names in sections:
            out.write(heading)
            out.write(",".join(names) + "\n")

    summary_file = os.path.join(OUT_DIR, "summary.txt")
    with open(summary_file, "w", encoding="ascii") as out:
        summary_lines = (
            ("rows_sampled: %d\n", rows),
            ("columns_total: %d\n", len(cols)),
            ("continuous: %d\n", len(continuous)),
            ("discrete: %d\n", len(discrete)),
            ("unknown: %d\n", len(unknown)),
            ("data_path: %s\n", DATA_PATH),
        )
        for template, value in summary_lines:
            out.write(template % value)
|
||||
|
||||
|
||||
def main():
    """Run the analysis on ``DATA_PATH`` and store the results.

    Exits via ``SystemExit`` with an explanatory message when the data file
    does not exist; otherwise samples up to ``MAX_ROWS`` rows with
    ``analyze`` and hands the outcome to ``write_results``.
    """
    if os.path.exists(DATA_PATH):
        outcome = analyze(DATA_PATH, MAX_ROWS)
        # outcome is (cols, continuous, discrete, unknown, rows), which is
        # exactly the positional order write_results expects.
        write_results(*outcome)
        return
    raise SystemExit("missing data file: %s" % DATA_PATH)
|
||||
|
||||
|
||||
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user