This commit is contained in:
2026-01-22 20:42:10 +08:00
parent f37a8ce179
commit 382c756dfe
10 changed files with 310 additions and 55 deletions

View File

@@ -10,7 +10,7 @@ from platform_utils import safe_path, ensure_dir
# Directory that contains this script (symlinks resolved).
BASE_DIR = Path(__file__).resolve().parent
# Two directory levels above the script directory.
REPO_DIR = BASE_DIR.parent.parent

# First training file, plus a wildcard path matching every "train*.csv.gz"
# in the same dataset directory. DATA_GLOB is a pattern, not a real file.
DATA_PATH = REPO_DIR.joinpath("dataset", "hai", "hai-21.03", "train1.csv.gz")
DATA_GLOB = REPO_DIR.joinpath("dataset", "hai", "hai-21.03", "train*.csv.gz")

# Feature split description, stored next to the script.
SPLIT_PATH = BASE_DIR.joinpath("feature_split.json")

# Output files, both under the script-local "results" directory.
OUT_STATS = BASE_DIR.joinpath("results", "cont_stats.json")
OUT_VOCAB = BASE_DIR.joinpath("results", "disc_vocab.json")
@@ -22,8 +22,13 @@ def main(max_rows: Optional[int] = None):
cont_cols = [c for c in split["continuous"] if c != time_col]
disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
mean, std = compute_cont_stats(safe_path(DATA_PATH), cont_cols, max_rows=max_rows)
vocab = build_vocab(safe_path(DATA_PATH), disc_cols, max_rows=max_rows)
data_paths = sorted(Path(REPO_DIR / "dataset" / "hai" / "hai-21.03").glob("train*.csv.gz"))
if not data_paths:
raise SystemExit("no train files found under %s" % str(DATA_GLOB))
data_paths = [safe_path(p) for p in data_paths]
mean, std = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows)
vocab = build_vocab(data_paths, disc_cols, max_rows=max_rows)
ensure_dir(OUT_STATS.parent)
with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f: