From 178fb7441c58262efd7935332f952b5e0a98e784 Mon Sep 17 00:00:00 2001
From: MingzheYang
Date: Thu, 22 Jan 2026 21:17:11 +0800
Subject: [PATCH] update

---
 example/README.md         |  2 ++
 example/data_utils.py     | 63 ++++++++++++++++++++++++++++++++++++---
 example/export_samples.py | 27 +++++++++++++++--
 example/prepare_data.py   | 22 ++++++++++----
 4 files changed, 102 insertions(+), 12 deletions(-)

diff --git a/example/README.md b/example/README.md
index 3199277..32ffe81 100644
--- a/example/README.md
+++ b/example/README.md
@@ -66,6 +66,8 @@ python example/run_pipeline.py --device auto
 - Continuous sampling is clipped in normalized space each step for stability.
 - Optional conditioning by file id (`train*.csv.gz`) is enabled by default for multi-file training.
 - Continuous head can be bounded with `tanh` via `use_tanh_eps` in config.
+- Export now clamps continuous features to the training min/max and preserves integer/decimal precision.
+- `<UNK>` tokens are replaced by the most frequent token for each discrete column at export.
 - The script only samples the first 5000 rows to stay fast.
 - `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
 - `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU.
diff --git a/example/data_utils.py b/example/data_utils.py
index 333ee45..a38b221 100755
--- a/example/data_utils.py
+++ b/example/data_utils.py
@@ -27,20 +27,41 @@ def compute_cont_stats(
     path: Union[str, List[str]],
     cont_cols: List[str],
     max_rows: Optional[int] = None,
-) -> Tuple[Dict[str, float], Dict[str, float]]:
-    """Streaming mean/std (Welford)."""
+) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, bool], Dict[str, int]]:
+    """Streaming mean/std (Welford) + min/max + int/precision metadata."""
     count = 0
     mean = {c: 0.0 for c in cont_cols}
     m2 = {c: 0.0 for c in cont_cols}
+    vmin = {c: float("inf") for c in cont_cols}
+    vmax = {c: float("-inf") for c in cont_cols}
+    int_like = {c: True for c in cont_cols}
+    max_decimals = {c: 0 for c in cont_cols}
     for i, row in enumerate(iter_rows(path)):
         count += 1
         for c in cont_cols:
-            x = float(row[c])
+            raw = row[c]
+            if raw is None or raw == "":
+                continue
+            x = float(raw)
             delta = x - mean[c]
             mean[c] += delta / count
             delta2 = x - mean[c]
             m2[c] += delta * delta2
+            if x < vmin[c]:
+                vmin[c] = x
+            if x > vmax[c]:
+                vmax[c] = x
+            if int_like[c] and abs(x - round(x)) > 1e-9:
+                int_like[c] = False
+            # track decimal places from the raw string if possible
+            if "e" in raw or "E" in raw:
+                # scientific notation, skip precision inference
+                continue
+            if "." in raw:
+                dec = raw.split(".", 1)[1].rstrip("0")
+                if len(dec) > max_decimals[c]:
+                    max_decimals[c] = len(dec)
         if max_rows is not None and i + 1 >= max_rows:
             break
@@ -51,7 +72,13 @@ def compute_cont_stats(
         else:
             var = 0.0
         std[c] = var ** 0.5 if var > 0 else 1.0
-    return mean, std
+    # replace infs if a column had no valid values
+    for c in cont_cols:
+        if vmin[c] == float("inf"):
+            vmin[c] = 0.0
+        if vmax[c] == float("-inf"):
+            vmax[c] = 0.0
+    return mean, std, vmin, vmax, int_like, max_decimals


 def build_vocab(
@@ -75,6 +102,34 @@
     return vocab


+def build_disc_stats(
+    path: Union[str, List[str]],
+    disc_cols: List[str],
+    max_rows: Optional[int] = None,
+) -> Tuple[Dict[str, Dict[str, int]], Dict[str, str]]:
+    counts = {c: {} for c in disc_cols}
+    for i, row in enumerate(iter_rows(path)):
+        for c in disc_cols:
+            val = row[c]
+            counts[c][val] = counts[c].get(val, 0) + 1
+        if max_rows is not None and i + 1 >= max_rows:
+            break
+
+    vocab = {}
+    top_token = {}
+    for c in disc_cols:
+        tokens = sorted(counts[c].keys())
+        if "<UNK>" not in tokens:
+            tokens.append("<UNK>")
+        vocab[c] = {tok: idx for idx, tok in enumerate(tokens)}
+        # most frequent token
+        if counts[c]:
+            top_token[c] = max(counts[c].items(), key=lambda kv: kv[1])[0]
+        else:
+            top_token[c] = "<UNK>"
+    return vocab, top_token
+
+
 def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[str, float]):
     import torch
     mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
diff --git a/example/export_samples.py b/example/export_samples.py
index 89ad9bb..3bfed15 100644
--- a/example/export_samples.py
+++ b/example/export_samples.py
@@ -107,8 +107,14 @@ def main():
     stats = load_stats(args.stats_path)
     mean = stats["mean"]
     std = stats["std"]
+    vmin = stats.get("min", {})
+    vmax = stats.get("max", {})
+    int_like = stats.get("int_like", {})
+    max_decimals = stats.get("max_decimals", {})

-    vocab = load_vocab(args.vocab_path)
+    vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
+    vocab = vocab_json["vocab"]
+    top_token = vocab_json.get("top_token", {})
     inv_vocab = build_inverse_vocab(vocab)
     vocab_sizes = [len(vocab[c]) for c in disc_cols]
@@ -214,6 +220,13 @@ def main():
     mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
     std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
     x_cont = x_cont * std_vec + mean_vec
+    # clamp to the observed min/max per feature
+    if vmin and vmax:
+        for i, c in enumerate(cont_cols):
+            lo = vmin.get(c, None)
+            hi = vmax.get(c, None)
+            if lo is not None and hi is not None:
+                x_cont[:, :, i] = torch.clamp(x_cont[:, :, i], float(lo), float(hi))

     header = read_header(data_path)
     out_cols = [c for c in header if c != time_col or args.include_time]
@@ -234,10 +247,18 @@ def main():
             if args.include_time and time_col in header:
                 row[time_col] = str(row_index)
             for i, c in enumerate(cont_cols):
-                row[c] = ("%.6f" % float(x_cont[b, t, i]))
+                val = float(x_cont[b, t, i])
+                if int_like.get(c, False):
+                    row[c] = str(int(round(val)))
+                else:
+                    dec = int(max_decimals.get(c, 6))
+                    fmt = ("%%.%df" % dec) if dec > 0 else "%.0f"
+                    row[c] = (fmt % val)
             for i, c in enumerate(disc_cols):
                 tok_idx = int(x_disc[b, t, i])
-                tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "0"
+                tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "<UNK>"
+                if tok == "<UNK>" and c in top_token:
+                    tok = top_token[c]
                 row[c] = tok
             writer.writerow(row)
             row_index += 1
diff --git a/example/prepare_data.py b/example/prepare_data.py
index d28c227..26eac22 100755
--- a/example/prepare_data.py
+++ b/example/prepare_data.py
@@ -5,7 +5,7 @@ import json
 from pathlib import Path
 from typing import Optional

-from data_utils import compute_cont_stats, build_vocab, load_split
+from data_utils import compute_cont_stats, build_disc_stats, load_split
 from platform_utils import safe_path, ensure_dir

 BASE_DIR = Path(__file__).resolve().parent
@@ -27,15 +27,27 @@ def main(max_rows: Optional[int] = None):
         raise SystemExit("no train files found under %s" % str(DATA_GLOB))
     data_paths = [safe_path(p) for p in data_paths]

-    mean, std = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows)
-    vocab = build_vocab(data_paths, disc_cols, max_rows=max_rows)
+    mean, std, vmin, vmax, int_like, max_decimals = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows)
+    vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows)

     ensure_dir(OUT_STATS.parent)
     with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f:
-        json.dump({"mean": mean, "std": std, "max_rows": max_rows}, f, indent=2)
+        json.dump(
+            {
+                "mean": mean,
+                "std": std,
+                "min": vmin,
+                "max": vmax,
+                "int_like": int_like,
+                "max_decimals": max_decimals,
+                "max_rows": max_rows,
+            },
+            f,
+            indent=2,
+        )
     with open(safe_path(OUT_VOCAB), "w", encoding="utf-8") as f:
-        json.dump({"vocab": vocab, "max_rows": max_rows}, f, indent=2)
+        json.dump({"vocab": vocab, "top_token": top_token, "max_rows": max_rows}, f, indent=2)


 if __name__ == "__main__":
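
For clarity, here is a minimal standalone sketch of the precision round-trip this patch implements: decimal places are inferred from the raw CSV strings in `prepare_data.py`, then replayed at export time. The helper names and sample values below are invented for illustration and are not part of the patch.

```python
# Sketch of the precision round-trip: infer integer-ness and decimal
# places from raw value strings, then format sampled floats with them.

def infer_precision(raw_values):
    int_like, max_decimals = True, 0
    for raw in raw_values:
        if raw is None or raw == "":
            continue  # blank cells carry no precision information
        x = float(raw)
        if int_like and abs(x - round(x)) > 1e-9:
            int_like = False
        if "e" in raw or "E" in raw:
            continue  # scientific notation: skip string-based inference
        if "." in raw:
            dec = len(raw.split(".", 1)[1].rstrip("0"))
            max_decimals = max(max_decimals, dec)
    return int_like, max_decimals

def format_value(val, int_like, max_decimals):
    if int_like:
        return str(int(round(val)))
    fmt = ("%%.%df" % max_decimals) if max_decimals > 0 else "%.0f"
    return fmt % val

int_like, dec = infer_precision(["3.50", "2.25", "10"])
assert (int_like, dec) == (False, 2)          # "3.50" counts as 1 decimal
assert format_value(2.71828, int_like, dec) == "2.72"
```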
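One caveat on the Welford change: `compute_cont_stats` shares a single row counter across all columns, so a column whose blank cells are skipped divides by a slightly-too-large `count`. A per-column counter keeps the streaming mean exact under the same skip-blank policy; the sketch below is illustrative, with invented names, and is not part of the patch.

```python
# Welford's streaming mean/std with a per-column sample counter, so
# missing values in one column do not bias that column's statistics.

def welford(columns):
    n = {c: 0 for c in columns}
    mean = {c: 0.0 for c in columns}
    m2 = {c: 0.0 for c in columns}

    def update(row):
        for c in columns:
            raw = row.get(c)
            if raw is None or raw == "":
                continue  # blank cell: leave this column's stats untouched
            x = float(raw)
            n[c] += 1
            delta = x - mean[c]
            mean[c] += delta / n[c]
            m2[c] += delta * (x - mean[c])

    def finalize():
        std = {}
        for c in columns:
            var = m2[c] / (n[c] - 1) if n[c] > 1 else 0.0
            std[c] = var ** 0.5 if var > 0 else 1.0
        return mean, std

    return update, finalize

update, finalize = welford(["price"])
for raw in ["1.0", "", "3.0"]:
    update({"price": raw})
mean, _ = finalize()
assert mean["price"] == 2.0  # the blank cell does not drag the mean down
```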
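Finally, a small sketch of the export-time `<UNK>` fallback: an out-of-range index decodes to `<UNK>`, which is then replaced by the column's most frequent training token. The vocab and top token here are invented for illustration.

```python
# Decode a sampled token index, falling back to the column's most
# frequent training token for <UNK> or out-of-range indices.

UNK = "<UNK>"

def decode_token(tok_idx, inv_vocab, top_token):
    tok = inv_vocab[tok_idx] if tok_idx < len(inv_vocab) else UNK
    return top_token if tok == UNK else tok

vocab = {"A": 0, "B": 1, UNK: 2}
inv_vocab = [tok for tok, _ in sorted(vocab.items(), key=lambda kv: kv[1])]
top_token = "B"  # most frequent token observed during prepare_data

assert decode_token(1, inv_vocab, top_token) == "B"
assert decode_token(2, inv_vocab, top_token) == "B"   # sampled <UNK>
assert decode_token(99, inv_vocab, top_token) == "B"  # index out of range
```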