连续型特征在时序相关性上的不足
This commit is contained in:
@@ -5,7 +5,7 @@ import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from data_utils import compute_cont_stats, build_disc_stats, load_split
|
||||
from data_utils import compute_cont_stats, build_disc_stats, load_split, choose_cont_transforms
|
||||
from platform_utils import safe_path, ensure_dir
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
@@ -27,20 +27,25 @@ def main(max_rows: Optional[int] = None):
|
||||
raise SystemExit("no train files found under %s" % str(DATA_GLOB))
|
||||
data_paths = [safe_path(p) for p in data_paths]
|
||||
|
||||
mean, std, vmin, vmax, int_like, max_decimals = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows)
|
||||
transforms, _ = choose_cont_transforms(data_paths, cont_cols, max_rows=max_rows)
|
||||
cont_stats = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows, transforms=transforms)
|
||||
vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows)
|
||||
|
||||
ensure_dir(OUT_STATS.parent)
|
||||
with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f:
|
||||
json.dump(
|
||||
{
|
||||
"mean": mean,
|
||||
"std": std,
|
||||
"min": vmin,
|
||||
"max": vmax,
|
||||
"int_like": int_like,
|
||||
"max_decimals": max_decimals,
|
||||
"max_rows": max_rows,
|
||||
"mean": cont_stats["mean"],
|
||||
"std": cont_stats["std"],
|
||||
"raw_mean": cont_stats["raw_mean"],
|
||||
"raw_std": cont_stats["raw_std"],
|
||||
"min": cont_stats["min"],
|
||||
"max": cont_stats["max"],
|
||||
"int_like": cont_stats["int_like"],
|
||||
"max_decimals": cont_stats["max_decimals"],
|
||||
"transform": cont_stats["transform"],
|
||||
"skew": cont_stats["skew"],
|
||||
"max_rows": cont_stats["max_rows"],
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
|
||||
Reference in New Issue
Block a user