连续型特征在时许相关性上的不足

This commit is contained in:
2026-01-23 15:06:52 +08:00
parent 0d17be9a1c
commit ff12324560
12 changed files with 1212 additions and 68 deletions

View File

@@ -111,6 +111,7 @@ def main():
vmax = stats.get("max", {})
int_like = stats.get("int_like", {})
max_decimals = stats.get("max_decimals", {})
transforms = stats.get("transform", {})
vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
vocab = vocab_json["vocab"]
@@ -141,6 +142,13 @@ def main():
model = HybridDiffusionModel(
cont_dim=len(cont_cols),
disc_vocab_sizes=vocab_sizes,
time_dim=int(cfg.get("model_time_dim", 64)),
hidden_dim=int(cfg.get("model_hidden_dim", 256)),
num_layers=int(cfg.get("model_num_layers", 1)),
dropout=float(cfg.get("model_dropout", 0.0)),
ff_mult=int(cfg.get("model_ff_mult", 2)),
pos_dim=int(cfg.get("model_pos_dim", 64)),
use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
cond_vocab_size=cond_vocab_size if use_condition else 0,
cond_dim=int(cfg.get("cond_dim", 32)),
use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
@@ -220,6 +228,9 @@ def main():
mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
x_cont = x_cont * std_vec + mean_vec
for i, c in enumerate(cont_cols):
if transforms.get(c) == "log1p":
x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
# clamp to observed min/max per feature
if vmin and vmax:
for i, c in enumerate(cont_cols):
@@ -246,8 +257,8 @@ def main():
row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1"
if args.include_time and time_col in header:
row[time_col] = str(row_index)
for i, c in enumerate(cont_cols):
val = float(x_cont[b, t, i])
for i, c in enumerate(cont_cols):
val = float(x_cont[b, t, i])
if int_like.get(c, False):
row[c] = str(int(round(val)))
else: