This commit is contained in:
2026-01-22 21:17:11 +08:00
parent 5a109f91ac
commit 178fb7441c
4 changed files with 102 additions and 12 deletions

View File

@@ -66,6 +66,8 @@ python example/run_pipeline.py --device auto
- Continuous sampling is clipped in normalized space each step for stability.
- Optional conditioning by file id (`train*.csv.gz`) is enabled by default for multi-file training.
- Continuous head can be bounded with `tanh` via `use_tanh_eps` in config.
- Export now clamps continuous features to training min/max and preserves integer/decimal precision.
- `<UNK>` tokens are replaced by the most frequent token for each discrete column at export.
- The script only samples the first 5000 rows to stay fast.
- `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
- `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU.

View File

@@ -27,20 +27,41 @@ def compute_cont_stats(
path: Union[str, List[str]],
cont_cols: List[str],
max_rows: Optional[int] = None,
) -> Tuple[Dict[str, float], Dict[str, float]]:
"""Streaming mean/std (Welford)."""
) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, bool], Dict[str, int]]:
"""Streaming mean/std (Welford) + min/max + int/precision metadata."""
count = 0
mean = {c: 0.0 for c in cont_cols}
m2 = {c: 0.0 for c in cont_cols}
vmin = {c: float("inf") for c in cont_cols}
vmax = {c: float("-inf") for c in cont_cols}
int_like = {c: True for c in cont_cols}
max_decimals = {c: 0 for c in cont_cols}
for i, row in enumerate(iter_rows(path)):
count += 1
for c in cont_cols:
x = float(row[c])
raw = row[c]
if raw is None or raw == "":
continue
x = float(raw)
delta = x - mean[c]
mean[c] += delta / count
delta2 = x - mean[c]
m2[c] += delta * delta2
if x < vmin[c]:
vmin[c] = x
if x > vmax[c]:
vmax[c] = x
if int_like[c] and abs(x - round(x)) > 1e-9:
int_like[c] = False
# track decimal places from raw string if possible
if "e" in raw or "E" in raw:
# scientific notation, skip precision inference
continue
if "." in raw:
dec = raw.split(".", 1)[1].rstrip("0")
if len(dec) > max_decimals[c]:
max_decimals[c] = len(dec)
if max_rows is not None and i + 1 >= max_rows:
break
@@ -51,7 +72,13 @@ def compute_cont_stats(
else:
var = 0.0
std[c] = var ** 0.5 if var > 0 else 1.0
return mean, std
# replace infs if column had no valid values
for c in cont_cols:
if vmin[c] == float("inf"):
vmin[c] = 0.0
if vmax[c] == float("-inf"):
vmax[c] = 0.0
return mean, std, vmin, vmax, int_like, max_decimals
def build_vocab(
@@ -75,6 +102,34 @@ def build_vocab(
return vocab
def build_disc_stats(
    path: Union[str, List[str]],
    disc_cols: List[str],
    max_rows: Optional[int] = None,
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, str]]:
    """Build a per-column vocabulary and most-frequent token for discrete columns.

    Streams rows from ``path`` (a single file or a list of files, as consumed
    by ``iter_rows``) and counts token frequencies for each discrete column.

    Args:
        path: Input path or list of paths passed to ``iter_rows``.
        disc_cols: Names of the discrete columns to index.
        max_rows: If given, stop after scanning this many rows.

    Returns:
        A ``(vocab, top_token)`` pair where ``vocab`` maps each column to a
        ``{token: index}`` dict (indices assigned over the sorted tokens, with
        an ``"<UNK>"`` entry guaranteed), and ``top_token`` maps each column to
        its most frequent observed token (``"<UNK>"`` when no values were seen).
    """
    counts: Dict[str, Dict[str, int]] = {c: {} for c in disc_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in disc_cols:
            val = row[c]
            # Skip missing/empty values, mirroring compute_cont_stats; a None
            # key would otherwise make sorted() below raise TypeError (str
            # and None are unorderable) and pollute the vocabulary.
            if val is None or val == "":
                continue
            counts[c][val] = counts[c].get(val, 0) + 1
        if max_rows is not None and i + 1 >= max_rows:
            break
    vocab: Dict[str, Dict[str, int]] = {}
    top_token: Dict[str, str] = {}
    for c in disc_cols:
        tokens = sorted(counts[c])
        if "<UNK>" not in tokens:
            tokens.append("<UNK>")
        vocab[c] = {tok: idx for idx, tok in enumerate(tokens)}
        # Most frequent token; used downstream to replace "<UNK>" at export.
        if counts[c]:
            top_token[c] = max(counts[c].items(), key=lambda kv: kv[1])[0]
        else:
            top_token[c] = "<UNK>"
    return vocab, top_token
def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[str, float]):
import torch
mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)

View File

@@ -107,8 +107,14 @@ def main():
stats = load_stats(args.stats_path)
mean = stats["mean"]
std = stats["std"]
vmin = stats.get("min", {})
vmax = stats.get("max", {})
int_like = stats.get("int_like", {})
max_decimals = stats.get("max_decimals", {})
vocab = load_vocab(args.vocab_path)
vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
vocab = vocab_json["vocab"]
top_token = vocab_json.get("top_token", {})
inv_vocab = build_inverse_vocab(vocab)
vocab_sizes = [len(vocab[c]) for c in disc_cols]
@@ -214,6 +220,13 @@ def main():
mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
x_cont = x_cont * std_vec + mean_vec
# clamp to observed min/max per feature
if vmin and vmax:
for i, c in enumerate(cont_cols):
lo = vmin.get(c, None)
hi = vmax.get(c, None)
if lo is not None and hi is not None:
x_cont[:, :, i] = torch.clamp(x_cont[:, :, i], float(lo), float(hi))
header = read_header(data_path)
out_cols = [c for c in header if c != time_col or args.include_time]
@@ -234,10 +247,18 @@ def main():
if args.include_time and time_col in header:
row[time_col] = str(row_index)
for i, c in enumerate(cont_cols):
row[c] = ("%.6f" % float(x_cont[b, t, i]))
val = float(x_cont[b, t, i])
if int_like.get(c, False):
row[c] = str(int(round(val)))
else:
dec = int(max_decimals.get(c, 6))
fmt = ("%%.%df" % dec) if dec > 0 else "%.0f"
row[c] = (fmt % val)
for i, c in enumerate(disc_cols):
tok_idx = int(x_disc[b, t, i])
tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "0"
tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "<UNK>"
if tok == "<UNK>" and c in top_token:
tok = top_token[c]
row[c] = tok
writer.writerow(row)
row_index += 1

View File

@@ -5,7 +5,7 @@ import json
from pathlib import Path
from typing import Optional
from data_utils import compute_cont_stats, build_vocab, load_split
from data_utils import compute_cont_stats, build_disc_stats, load_split
from platform_utils import safe_path, ensure_dir
BASE_DIR = Path(__file__).resolve().parent
@@ -27,15 +27,27 @@ def main(max_rows: Optional[int] = None):
raise SystemExit("no train files found under %s" % str(DATA_GLOB))
data_paths = [safe_path(p) for p in data_paths]
mean, std = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows)
vocab = build_vocab(data_paths, disc_cols, max_rows=max_rows)
mean, std, vmin, vmax, int_like, max_decimals = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows)
vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows)
ensure_dir(OUT_STATS.parent)
with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f:
json.dump({"mean": mean, "std": std, "max_rows": max_rows}, f, indent=2)
json.dump(
{
"mean": mean,
"std": std,
"min": vmin,
"max": vmax,
"int_like": int_like,
"max_decimals": max_decimals,
"max_rows": max_rows,
},
f,
indent=2,
)
with open(safe_path(OUT_VOCAB), "w", encoding="utf-8") as f:
json.dump({"vocab": vocab, "max_rows": max_rows}, f, indent=2)
json.dump({"vocab": vocab, "top_token": top_token, "max_rows": max_rows}, f, indent=2)
if __name__ == "__main__":