transformer
@@ -144,6 +144,11 @@ def main():
     temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
     temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
     temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
+    backbone_type = str(cfg.get("backbone_type", "gru"))
+    transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
+    transformer_nhead = int(cfg.get("transformer_nhead", 4))
+    transformer_ff_dim = int(cfg.get("transformer_ff_dim", 512))
+    transformer_dropout = float(cfg.get("transformer_dropout", 0.1))
 
     model = HybridDiffusionModel(
         cont_dim=len(cont_cols),
@@ -155,6 +160,11 @@ def main():
         ff_mult=int(cfg.get("model_ff_mult", 2)),
         pos_dim=int(cfg.get("model_pos_dim", 64)),
         use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
+        backbone_type=backbone_type,
+        transformer_num_layers=transformer_num_layers,
+        transformer_nhead=transformer_nhead,
+        transformer_ff_dim=transformer_ff_dim,
+        transformer_dropout=transformer_dropout,
         cond_vocab_size=cond_vocab_size if use_condition else 0,
         cond_dim=int(cfg.get("cond_dim", 32)),
         use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
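The five new cfg keys all fall back to their cfg.get defaults, so existing configs that never mention them keep the GRU backbone. For reference, a hypothetical cfg mapping (the actual config file format is not shown in this commit; the key names mirror the cfg.get calls above) that switches to the transformer:

cfg = {
    "temporal_hidden_dim": 256,
    "backbone_type": "transformer",   # default is "gru"
    "transformer_num_layers": 2,
    "transformer_nhead": 4,           # head count must divide the model width
    "transformer_ff_dim": 512,
    "transformer_dropout": 0.1,
}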
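The commit only threads the new arguments through main(); HybridDiffusionModel's internals are not part of the diff. A minimal sketch of what the backbone switch might look like, assuming it maps onto standard PyTorch modules (nn.GRU vs. nn.TransformerEncoder) and that temporal_hidden_dim is the shared model width; every name here beyond the config keys is hypothetical:

import torch
import torch.nn as nn

class TemporalBackbone(nn.Module):
    # Hypothetical stand-in for the backbone switch inside HybridDiffusionModel.
    def __init__(self, hidden_dim=256, backbone_type="gru",
                 temporal_num_layers=1, temporal_dropout=0.0,
                 transformer_num_layers=2, transformer_nhead=4,
                 transformer_ff_dim=512, transformer_dropout=0.1):
        super().__init__()
        self.backbone_type = backbone_type
        if backbone_type == "gru":
            # Inter-layer dropout only takes effect for stacked GRUs.
            self.net = nn.GRU(
                input_size=hidden_dim,
                hidden_size=hidden_dim,
                num_layers=temporal_num_layers,
                dropout=temporal_dropout if temporal_num_layers > 1 else 0.0,
                batch_first=True,
            )
        elif backbone_type == "transformer":
            layer = nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=transformer_nhead,
                dim_feedforward=transformer_ff_dim,
                dropout=transformer_dropout,
                batch_first=True,
            )
            self.net = nn.TransformerEncoder(layer, num_layers=transformer_num_layers)
        else:
            raise ValueError(f"unknown backbone_type: {backbone_type!r}")

    def forward(self, x):
        # x: (batch, seq_len, hidden_dim); both branches preserve that shape.
        if self.backbone_type == "gru":
            out, _ = self.net(x)  # nn.GRU returns (output, final_hidden)
            return out
        return self.net(x)

x = torch.randn(8, 32, 256)
print(TemporalBackbone(backbone_type="transformer")(x).shape)  # (8, 32, 256)

Keeping both branches shape-preserving is what lets backbone_type act as a drop-in switch; note that multi-head attention requires the model width to be divisible by transformer_nhead (256 / 4 here).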