连续型特征在时许相关性上的不足

This commit is contained in:
2026-01-23 15:06:52 +08:00
parent 0d17be9a1c
commit ff12324560
12 changed files with 1212 additions and 68 deletions

View File

@@ -47,6 +47,13 @@ def main():
cond_dim = int(cfg.get("cond_dim", 32))
use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
eps_scale = float(cfg.get("eps_scale", 1.0))
model_time_dim = int(cfg.get("model_time_dim", 64))
model_hidden_dim = int(cfg.get("model_hidden_dim", 256))
model_num_layers = int(cfg.get("model_num_layers", 1))
model_dropout = float(cfg.get("model_dropout", 0.0))
model_ff_mult = int(cfg.get("model_ff_mult", 2))
model_pos_dim = int(cfg.get("model_pos_dim", 64))
model_use_pos = bool(cfg.get("model_use_pos_embed", True))
split = load_split(str(SPLIT_PATH))
time_col = split.get("time_column", "time")
@@ -67,6 +74,13 @@ def main():
model = HybridDiffusionModel(
cont_dim=len(cont_cols),
disc_vocab_sizes=vocab_sizes,
time_dim=model_time_dim,
hidden_dim=model_hidden_dim,
num_layers=model_num_layers,
dropout=model_dropout,
ff_mult=model_ff_mult,
pos_dim=model_pos_dim,
use_pos_embed=model_use_pos,
cond_vocab_size=cond_vocab_size,
cond_dim=cond_dim,
use_tanh_eps=use_tanh_eps,