update
@@ -107,8 +107,14 @@ def main():
     stats = load_stats(args.stats_path)
     mean = stats["mean"]
     std = stats["std"]
+    vmin = stats.get("min", {})
+    vmax = stats.get("max", {})
+    int_like = stats.get("int_like", {})
+    max_decimals = stats.get("max_decimals", {})
 
-    vocab = load_vocab(args.vocab_path)
+    vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
+    vocab = vocab_json["vocab"]
+    top_token = vocab_json.get("top_token", {})
     inv_vocab = build_inverse_vocab(vocab)
     vocab_sizes = [len(vocab[c]) for c in disc_cols]
 
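For reference, a minimal sketch of the vocab JSON layout this hunk appears to expect, inferred only from the keys read above; load_stats, load_vocab and build_inverse_vocab live elsewhere in the repo, and the column names, tokens, and the _sketch helper below are illustrative, not part of the codebase.

def build_inverse_vocab_sketch(vocab):
    # map each column's {token: index} dict to an index -> token list
    inv = {}
    for col, mapping in vocab.items():
        inv_list = ["<UNK>"] * len(mapping)
        for tok, idx in mapping.items():
            inv_list[int(idx)] = tok
        inv[col] = inv_list
    return inv

vocab_json = {
    "vocab": {"color": {"red": 0, "blue": 1, "<UNK>": 2}},
    "top_token": {"color": "red"},  # most frequent raw value per column
}
inv_vocab = build_inverse_vocab_sketch(vocab_json["vocab"])
assert inv_vocab["color"][1] == "blue"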
@@ -214,6 +220,13 @@ def main():
     mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
     std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
     x_cont = x_cont * std_vec + mean_vec
+    # clamp to observed min/max per feature
+    if vmin and vmax:
+        for i, c in enumerate(cont_cols):
+            lo = vmin.get(c, None)
+            hi = vmax.get(c, None)
+            if lo is not None and hi is not None:
+                x_cont[:, :, i] = torch.clamp(x_cont[:, :, i], float(lo), float(hi))
 
     header = read_header(data_path)
     out_cols = [c for c in header if c != time_col or args.include_time]
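A standalone sketch of the de-normalization and clamping step added here, assuming x_cont has shape [batch, time, num_continuous_features] and that the stats file stores one scalar per column; the column names and sample values are made up.

import torch

cont_cols = ["price", "qty"]                 # illustrative column names
mean = {"price": 10.0, "qty": 3.0}
std = {"price": 2.0, "qty": 1.0}
vmin = {"price": 0.0, "qty": 0.0}
vmax = {"price": 100.0, "qty": 50.0}

x_cont = torch.randn(4, 16, len(cont_cols))  # normalized model output
mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
x_cont = x_cont * std_vec + mean_vec         # undo z-score normalization
for i, c in enumerate(cont_cols):
    # keep each feature inside the range observed in the training data
    x_cont[:, :, i] = torch.clamp(x_cont[:, :, i], vmin[c], vmax[c])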
@@ -234,10 +247,18 @@ def main():
             if args.include_time and time_col in header:
                 row[time_col] = str(row_index)
             for i, c in enumerate(cont_cols):
-                row[c] = ("%.6f" % float(x_cont[b, t, i]))
+                val = float(x_cont[b, t, i])
+                if int_like.get(c, False):
+                    row[c] = str(int(round(val)))
+                else:
+                    dec = int(max_decimals.get(c, 6))
+                    fmt = ("%%.%df" % dec) if dec > 0 else "%.0f"
+                    row[c] = (fmt % val)
             for i, c in enumerate(disc_cols):
                 tok_idx = int(x_disc[b, t, i])
-                tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "0"
+                tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "<UNK>"
+                if tok == "<UNK>" and c in top_token:
+                    tok = top_token[c]
                 row[c] = tok
             writer.writerow(row)
             row_index += 1
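The formatting and token-decoding rules added in this hunk, isolated as plain functions for clarity; the column names, sample values, and helper names below are illustrative rather than part of the repo.

int_like = {"qty": True}          # columns whose values are whole numbers
max_decimals = {"price": 2}       # decimals observed per column at preprocessing
top_token = {"color": "red"}      # most frequent token per discrete column

def format_value(col, val):
    # integer-like columns are rounded; others keep their recorded precision
    if int_like.get(col, False):
        return str(int(round(val)))
    dec = int(max_decimals.get(col, 6))
    fmt = ("%%.%df" % dec) if dec > 0 else "%.0f"
    return fmt % val

def decode_token(col, tok):
    # fall back to the column's most frequent token instead of writing "<UNK>"
    if tok == "<UNK>" and col in top_token:
        return top_token[col]
    return tok

print(format_value("qty", 3.7))        # -> 4
print(format_value("price", 12.3456))  # -> 12.35
print(decode_token("color", "<UNK>"))  # -> red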