update2
This commit is contained in:
@@ -17,6 +17,14 @@ OUT_VOCAB = BASE_DIR / "results" / "disc_vocab.json"
|
||||
|
||||
|
||||
def main(max_rows: Optional[int] = None):
|
||||
config_path = BASE_DIR / "config.json"
|
||||
use_quantile = False
|
||||
quantile_bins = None
|
||||
if config_path.exists():
|
||||
cfg = json.loads(config_path.read_text(encoding="utf-8"))
|
||||
use_quantile = bool(cfg.get("use_quantile_transform", False))
|
||||
quantile_bins = int(cfg.get("quantile_bins", 0)) if use_quantile else None
|
||||
|
||||
split = load_split(safe_path(SPLIT_PATH))
|
||||
time_col = split.get("time_column", "time")
|
||||
cont_cols = [c for c in split["continuous"] if c != time_col]
|
||||
@@ -28,7 +36,13 @@ def main(max_rows: Optional[int] = None):
|
||||
data_paths = [safe_path(p) for p in data_paths]
|
||||
|
||||
transforms, _ = choose_cont_transforms(data_paths, cont_cols, max_rows=max_rows)
|
||||
cont_stats = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows, transforms=transforms)
|
||||
cont_stats = compute_cont_stats(
|
||||
data_paths,
|
||||
cont_cols,
|
||||
max_rows=max_rows,
|
||||
transforms=transforms,
|
||||
quantile_bins=quantile_bins,
|
||||
)
|
||||
vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows)
|
||||
|
||||
ensure_dir(OUT_STATS.parent)
|
||||
@@ -46,6 +60,8 @@ def main(max_rows: Optional[int] = None):
|
||||
"transform": cont_stats["transform"],
|
||||
"skew": cont_stats["skew"],
|
||||
"max_rows": cont_stats["max_rows"],
|
||||
"quantile_probs": cont_stats["quantile_probs"],
|
||||
"quantile_values": cont_stats["quantile_values"],
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
|
||||
Reference in New Issue
Block a user