This commit is contained in:
2026-01-22 21:17:11 +08:00
parent 5a109f91ac
commit 178fb7441c
4 changed files with 102 additions and 12 deletions

View File

@@ -5,7 +5,7 @@ import json
from pathlib import Path
from typing import Optional
from data_utils import compute_cont_stats, build_vocab, load_split
from data_utils import compute_cont_stats, build_disc_stats, load_split
from platform_utils import safe_path, ensure_dir
BASE_DIR = Path(__file__).resolve().parent
@@ -27,15 +27,27 @@ def main(max_rows: Optional[int] = None):
raise SystemExit("no train files found under %s" % str(DATA_GLOB))
data_paths = [safe_path(p) for p in data_paths]
mean, std = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows)
vocab = build_vocab(data_paths, disc_cols, max_rows=max_rows)
mean, std, vmin, vmax, int_like, max_decimals = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows)
vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows)
ensure_dir(OUT_STATS.parent)
with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f:
json.dump({"mean": mean, "std": std, "max_rows": max_rows}, f, indent=2)
json.dump(
{
"mean": mean,
"std": std,
"min": vmin,
"max": vmax,
"int_like": int_like,
"max_decimals": max_decimals,
"max_rows": max_rows,
},
f,
indent=2,
)
with open(safe_path(OUT_VOCAB), "w", encoding="utf-8") as f:
json.dump({"vocab": vocab, "max_rows": max_rows}, f, indent=2)
json.dump({"vocab": vocab, "top_token": top_token, "max_rows": max_rows}, f, indent=2)
if __name__ == "__main__":