This commit is contained in:
2026-01-22 21:17:11 +08:00
parent 5a109f91ac
commit 178fb7441c
4 changed files with 102 additions and 12 deletions

View File

@@ -107,8 +107,14 @@ def main():
stats = load_stats(args.stats_path)
mean = stats["mean"]
std = stats["std"]
vmin = stats.get("min", {})
vmax = stats.get("max", {})
int_like = stats.get("int_like", {})
max_decimals = stats.get("max_decimals", {})
vocab = load_vocab(args.vocab_path)
vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
vocab = vocab_json["vocab"]
top_token = vocab_json.get("top_token", {})
inv_vocab = build_inverse_vocab(vocab)
vocab_sizes = [len(vocab[c]) for c in disc_cols]
@@ -214,6 +220,13 @@ def main():
mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
x_cont = x_cont * std_vec + mean_vec
# clamp to observed min/max per feature
if vmin and vmax:
for i, c in enumerate(cont_cols):
lo = vmin.get(c, None)
hi = vmax.get(c, None)
if lo is not None and hi is not None:
x_cont[:, :, i] = torch.clamp(x_cont[:, :, i], float(lo), float(hi))
header = read_header(data_path)
out_cols = [c for c in header if c != time_col or args.include_time]
@@ -234,10 +247,18 @@ def main():
if args.include_time and time_col in header:
row[time_col] = str(row_index)
for i, c in enumerate(cont_cols):
row[c] = ("%.6f" % float(x_cont[b, t, i]))
val = float(x_cont[b, t, i])
if int_like.get(c, False):
row[c] = str(int(round(val)))
else:
dec = int(max_decimals.get(c, 6))
fmt = ("%%.%df" % dec) if dec > 0 else "%.0f"
row[c] = (fmt % val)
for i, c in enumerate(disc_cols):
tok_idx = int(x_disc[b, t, i])
tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "0"
tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "<UNK>"
if tok == "<UNK>" and c in top_token:
tok = top_token[c]
row[c] = tok
writer.writerow(row)
row_index += 1