update
@@ -66,6 +66,8 @@ python example/run_pipeline.py --device auto
 - Continuous sampling is clipped in normalized space at each step for stability.
 - Optional conditioning by file id (`train*.csv.gz`) is enabled by default for multi-file training.
 - The continuous head can be bounded with `tanh` via `use_tanh_eps` in the config (see the sketch after this list).
+- Export now clamps continuous features to the training min/max and preserves integer/decimal precision.
+- `<UNK>` tokens are replaced by the most frequent token of each discrete column at export.
 - The script samples only the first 5000 rows to stay fast.
 - `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
 - `train.py` and `sample.py` auto-select a GPU if available; otherwise they fall back to CPU.
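A hedged sketch of what the two stabilizers above typically look like; the bound `3.0`, the epsilon, and the names `z`/`head_out` are illustrative assumptions, not this project's exact code:

    import torch

    # clip each sampled step in normalized (z-score) space
    z = torch.clamp(z, -3.0, 3.0)            # assumed bound of 3 standard deviations

    # bound the continuous head with a scaled tanh (cf. use_tanh_eps)
    eps = 1e-3                               # assumed epsilon
    z = (1.0 - eps) * torch.tanh(head_out)   # stays strictly inside (-1, 1)
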
@@ -27,20 +27,41 @@ def compute_cont_stats(
     path: Union[str, List[str]],
     cont_cols: List[str],
     max_rows: Optional[int] = None,
-) -> Tuple[Dict[str, float], Dict[str, float]]:
-    """Streaming mean/std (Welford)."""
+) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, bool], Dict[str, int]]:
+    """Streaming mean/std (Welford) + min/max + int/precision metadata."""
     count = 0
     mean = {c: 0.0 for c in cont_cols}
     m2 = {c: 0.0 for c in cont_cols}
+    vmin = {c: float("inf") for c in cont_cols}
+    vmax = {c: float("-inf") for c in cont_cols}
+    int_like = {c: True for c in cont_cols}
+    max_decimals = {c: 0 for c in cont_cols}

     for i, row in enumerate(iter_rows(path)):
         count += 1
         for c in cont_cols:
-            x = float(row[c])
+            raw = row[c]
+            if raw is None or raw == "":
+                continue
+            x = float(raw)
             delta = x - mean[c]
             mean[c] += delta / count
             delta2 = x - mean[c]
             m2[c] += delta * delta2
+            if x < vmin[c]:
+                vmin[c] = x
+            if x > vmax[c]:
+                vmax[c] = x
+            if int_like[c] and abs(x - round(x)) > 1e-9:
+                int_like[c] = False
+            # track decimal places from raw string if possible
+            if "e" in raw or "E" in raw:
+                # scientific notation, skip precision inference
+                continue
+            if "." in raw:
+                dec = raw.split(".", 1)[1].rstrip("0")
+                if len(dec) > max_decimals[c]:
+                    max_decimals[c] = len(dec)
         if max_rows is not None and i + 1 >= max_rows:
             break

@@ -51,7 +72,13 @@ def compute_cont_stats(
         else:
             var = 0.0
         std[c] = var ** 0.5 if var > 0 else 1.0
-    return mean, std
+    # replace infs if column had no valid values
+    for c in cont_cols:
+        if vmin[c] == float("inf"):
+            vmin[c] = 0.0
+        if vmax[c] == float("-inf"):
+            vmax[c] = 0.0
+    return mean, std, vmin, vmax, int_like, max_decimals


 def build_vocab(
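One caveat with the new skip logic: `count` is still incremented once per row, so a column with many empty cells divides by more observations than it actually has, biasing that column's mean and std. A minimal self-contained sketch with per-column counts (the function name `welford_stats` and the toy rows are invented for illustration, not part of this commit):

    from typing import Dict, Iterable, List, Tuple

    def welford_stats(rows: Iterable[Dict[str, str]], cols: List[str]) -> Tuple[Dict[str, float], Dict[str, float]]:
        # per-column counts so missing values do not bias the estimates
        n = {c: 0 for c in cols}
        mean = {c: 0.0 for c in cols}
        m2 = {c: 0.0 for c in cols}
        for row in rows:
            for c in cols:
                raw = row.get(c)
                if raw is None or raw == "":
                    continue  # skip missing cells instead of counting them
                x = float(raw)
                n[c] += 1
                delta = x - mean[c]
                mean[c] += delta / n[c]
                m2[c] += delta * (x - mean[c])
        std = {c: (m2[c] / (n[c] - 1)) ** 0.5 if n[c] > 1 else 1.0 for c in cols}
        return mean, std

    rows = [{"a": "1.0"}, {"a": ""}, {"a": "3.0"}]
    mean, std = welford_stats(rows, ["a"])
    print(mean["a"], std["a"])  # 2.0 1.414..., unbiased despite the empty cell
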
@@ -75,6 +102,34 @@ def build_vocab(
     return vocab


+def build_disc_stats(
+    path: Union[str, List[str]],
+    disc_cols: List[str],
+    max_rows: Optional[int] = None,
+) -> Tuple[Dict[str, Dict[str, int]], Dict[str, str]]:
+    counts = {c: {} for c in disc_cols}
+    for i, row in enumerate(iter_rows(path)):
+        for c in disc_cols:
+            val = row[c]
+            counts[c][val] = counts[c].get(val, 0) + 1
+        if max_rows is not None and i + 1 >= max_rows:
+            break
+
+    vocab = {}
+    top_token = {}
+    for c in disc_cols:
+        tokens = sorted(counts[c].keys())
+        if "<UNK>" not in tokens:
+            tokens.append("<UNK>")
+        vocab[c] = {tok: idx for idx, tok in enumerate(tokens)}
+        # most frequent token
+        if counts[c]:
+            top_token[c] = max(counts[c].items(), key=lambda kv: kv[1])[0]
+        else:
+            top_token[c] = "<UNK>"
+    return vocab, top_token
+
+
 def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[str, float]):
     import torch
     mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)

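The design here: `<UNK>` always gets a vocab index so an out-of-range sampled index has somewhere to map, and `top_token` records the per-column replacement used at export. A toy walk-through of the two outputs (values invented):

    counts = {"color": {"red": 2, "blue": 1}}              # as build_disc_stats would count them
    tokens = sorted(counts["color"]) + ["<UNK>"]           # ['blue', 'red', '<UNK>']
    vocab = {tok: idx for idx, tok in enumerate(tokens)}   # {'blue': 0, 'red': 1, '<UNK>': 2}
    top_token = max(counts["color"].items(), key=lambda kv: kv[1])[0]  # 'red'
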
@@ -107,8 +107,14 @@ def main():
     stats = load_stats(args.stats_path)
     mean = stats["mean"]
     std = stats["std"]
+    vmin = stats.get("min", {})
+    vmax = stats.get("max", {})
+    int_like = stats.get("int_like", {})
+    max_decimals = stats.get("max_decimals", {})

-    vocab = load_vocab(args.vocab_path)
+    vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
+    vocab = vocab_json["vocab"]
+    top_token = vocab_json.get("top_token", {})
     inv_vocab = build_inverse_vocab(vocab)
     vocab_sizes = [len(vocab[c]) for c in disc_cols]

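Small style note on the new vocab load (which also assumes `json` is already imported in sample.py): the bare `open()` leaves the file handle to the garbage collector, while a context manager closes it deterministically:

    with open(args.vocab_path, "r", encoding="utf-8") as f:
        vocab_json = json.load(f)
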
@@ -214,6 +220,13 @@ def main():
     mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
     std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
     x_cont = x_cont * std_vec + mean_vec
+    # clamp to observed min/max per feature
+    if vmin and vmax:
+        for i, c in enumerate(cont_cols):
+            lo = vmin.get(c, None)
+            hi = vmax.get(c, None)
+            if lo is not None and hi is not None:
+                x_cont[:, :, i] = torch.clamp(x_cont[:, :, i], float(lo), float(hi))

     header = read_header(data_path)
     out_cols = [c for c in header if c != time_col or args.include_time]
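The per-feature loop is fine; as an aside, `torch.clamp` also accepts tensor bounds (PyTorch 1.9+), so the same clamp can be a single broadcasted call, sketched here under the assumption that every column has recorded bounds:

    lo = torch.tensor([float(vmin[c]) for c in cont_cols], dtype=x_cont.dtype)
    hi = torch.tensor([float(vmax[c]) for c in cont_cols], dtype=x_cont.dtype)
    x_cont = torch.clamp(x_cont, min=lo, max=hi)  # broadcasts over (batch, time, feature)
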
@@ -234,10 +247,18 @@ def main():
             if args.include_time and time_col in header:
                 row[time_col] = str(row_index)
             for i, c in enumerate(cont_cols):
-                row[c] = ("%.6f" % float(x_cont[b, t, i]))
+                val = float(x_cont[b, t, i])
+                if int_like.get(c, False):
+                    row[c] = str(int(round(val)))
+                else:
+                    dec = int(max_decimals.get(c, 6))
+                    fmt = ("%%.%df" % dec) if dec > 0 else "%.0f"
+                    row[c] = (fmt % val)
             for i, c in enumerate(disc_cols):
                 tok_idx = int(x_disc[b, t, i])
-                tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "0"
+                tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "<UNK>"
+                if tok == "<UNK>" and c in top_token:
+                    tok = top_token[c]
                 row[c] = tok
             writer.writerow(row)
             row_index += 1

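The doubled `%%` escapes the percent sign, so the first interpolation builds a format string from the inferred precision and the second applies it:

    dec = 3
    fmt = "%%.%df" % dec    # -> "%.3f"
    print(fmt % 1.23456)    # prints "1.235"
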
@@ -5,7 +5,7 @@ import json
 from pathlib import Path
 from typing import Optional

-from data_utils import compute_cont_stats, build_vocab, load_split
+from data_utils import compute_cont_stats, build_disc_stats, load_split
 from platform_utils import safe_path, ensure_dir

 BASE_DIR = Path(__file__).resolve().parent
@@ -27,15 +27,27 @@ def main(max_rows: Optional[int] = None):
         raise SystemExit("no train files found under %s" % str(DATA_GLOB))
     data_paths = [safe_path(p) for p in data_paths]

-    mean, std = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows)
-    vocab = build_vocab(data_paths, disc_cols, max_rows=max_rows)
+    mean, std, vmin, vmax, int_like, max_decimals = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows)
+    vocab, top_token = build_disc_stats(data_paths, disc_cols, max_rows=max_rows)

     ensure_dir(OUT_STATS.parent)
     with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f:
-        json.dump({"mean": mean, "std": std, "max_rows": max_rows}, f, indent=2)
+        json.dump(
+            {
+                "mean": mean,
+                "std": std,
+                "min": vmin,
+                "max": vmax,
+                "int_like": int_like,
+                "max_decimals": max_decimals,
+                "max_rows": max_rows,
+            },
+            f,
+            indent=2,
+        )

     with open(safe_path(OUT_VOCAB), "w", encoding="utf-8") as f:
-        json.dump({"vocab": vocab, "max_rows": max_rows}, f, indent=2)
+        json.dump({"vocab": vocab, "top_token": top_token, "max_rows": max_rows}, f, indent=2)


 if __name__ == "__main__":
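With these changes, `stats.json` carries everything `sample.py` needs for export. A hypothetical excerpt for one integer-like column and one two-decimal column (column names and values invented):

    {
      "mean": {"age": 41.3, "price": 19.87},
      "std": {"age": 12.1, "price": 4.02},
      "min": {"age": 18.0, "price": 0.99},
      "max": {"age": 90.0, "price": 49.99},
      "int_like": {"age": true, "price": false},
      "max_decimals": {"age": 0, "price": 2},
      "max_rows": null
    }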