Make the code run on both Windows and Linux
This commit is contained in:
@@ -12,13 +12,13 @@ from pathlib import Path
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
REPO_DIR = BASE_DIR.parent.parent
|
||||
DATA_PATH = str(REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz")
|
||||
OUT_DIR = str(BASE_DIR / "results")
|
||||
DATA_PATH = REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"
|
||||
OUT_DIR = BASE_DIR / "results"
|
||||
MAX_ROWS = 5000
|
||||
|
||||
|
||||
def analyze(path: str, max_rows: int):
|
||||
with gzip.open(path, "rt", newline="") as f:
|
||||
with gzip.open(str(path), "rt", newline="") as f:
|
||||
reader = csv.reader(f)
|
||||
cols = next(reader)
|
||||
stats = {
|
||||
@@ -74,11 +74,11 @@ def analyze(path: str, max_rows: int):
|
||||
|
||||
|
||||
def write_results(cols, continuous, discrete, unknown, rows):
|
||||
os.makedirs(OUT_DIR, exist_ok=True)
|
||||
split_path = os.path.join(OUT_DIR, "feature_split.txt")
|
||||
summary_path = os.path.join(OUT_DIR, "summary.txt")
|
||||
os.makedirs(str(OUT_DIR), exist_ok=True)
|
||||
split_path = OUT_DIR / "feature_split.txt"
|
||||
summary_path = OUT_DIR / "summary.txt"
|
||||
|
||||
with open(split_path, "w", encoding="ascii") as f:
|
||||
with open(split_path, "w", encoding="utf-8") as f:
|
||||
f.write("discrete\n")
|
||||
f.write(",".join(discrete) + "\n")
|
||||
f.write("continuous\n")
|
||||
@@ -87,19 +87,19 @@ def write_results(cols, continuous, discrete, unknown, rows):
|
||||
f.write("unknown\n")
|
||||
f.write(",".join(unknown) + "\n")
|
||||
|
||||
with open(summary_path, "w", encoding="ascii") as f:
|
||||
with open(summary_path, "w", encoding="utf-8") as f:
|
||||
f.write("rows_sampled: %d\n" % rows)
|
||||
f.write("columns_total: %d\n" % len(cols))
|
||||
f.write("continuous: %d\n" % len(continuous))
|
||||
f.write("discrete: %d\n" % len(discrete))
|
||||
f.write("unknown: %d\n" % len(unknown))
|
||||
f.write("data_path: %s\n" % DATA_PATH)
|
||||
f.write("data_path: %s\n" % str(DATA_PATH))
|
||||
|
||||
|
||||
def main():
|
||||
if not os.path.exists(DATA_PATH):
|
||||
raise SystemExit("missing data file: %s" % DATA_PATH)
|
||||
cols, continuous, discrete, unknown, rows = analyze(DATA_PATH, MAX_ROWS)
|
||||
if not DATA_PATH.exists():
|
||||
raise SystemExit("missing data file: %s" % str(DATA_PATH))
|
||||
cols, continuous, discrete, unknown, rows = analyze(str(DATA_PATH), MAX_ROWS)
|
||||
write_results(cols, continuous, discrete, unknown, rows)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user