连续型特征在时许相关性上的不足
This commit is contained in:
@@ -111,6 +111,7 @@ def main():
|
||||
vmax = stats.get("max", {})
|
||||
int_like = stats.get("int_like", {})
|
||||
max_decimals = stats.get("max_decimals", {})
|
||||
transforms = stats.get("transform", {})
|
||||
|
||||
vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
|
||||
vocab = vocab_json["vocab"]
|
||||
@@ -141,6 +142,13 @@ def main():
|
||||
model = HybridDiffusionModel(
|
||||
cont_dim=len(cont_cols),
|
||||
disc_vocab_sizes=vocab_sizes,
|
||||
time_dim=int(cfg.get("model_time_dim", 64)),
|
||||
hidden_dim=int(cfg.get("model_hidden_dim", 256)),
|
||||
num_layers=int(cfg.get("model_num_layers", 1)),
|
||||
dropout=float(cfg.get("model_dropout", 0.0)),
|
||||
ff_mult=int(cfg.get("model_ff_mult", 2)),
|
||||
pos_dim=int(cfg.get("model_pos_dim", 64)),
|
||||
use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
|
||||
cond_vocab_size=cond_vocab_size if use_condition else 0,
|
||||
cond_dim=int(cfg.get("cond_dim", 32)),
|
||||
use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
|
||||
@@ -220,6 +228,9 @@ def main():
|
||||
mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
|
||||
std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
|
||||
x_cont = x_cont * std_vec + mean_vec
|
||||
for i, c in enumerate(cont_cols):
|
||||
if transforms.get(c) == "log1p":
|
||||
x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
|
||||
# clamp to observed min/max per feature
|
||||
if vmin and vmax:
|
||||
for i, c in enumerate(cont_cols):
|
||||
@@ -246,8 +257,8 @@ def main():
|
||||
row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1"
|
||||
if args.include_time and time_col in header:
|
||||
row[time_col] = str(row_index)
|
||||
for i, c in enumerate(cont_cols):
|
||||
val = float(x_cont[b, t, i])
|
||||
for i, c in enumerate(cont_cols):
|
||||
val = float(x_cont[b, t, i])
|
||||
if int_like.get(c, False):
|
||||
row[c] = str(int(round(val)))
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user