transformer
This commit is contained in:
@@ -200,6 +200,11 @@ def main():
|
||||
ff_mult=int(config.get("model_ff_mult", 2)),
|
||||
pos_dim=int(config.get("model_pos_dim", 64)),
|
||||
use_pos_embed=bool(config.get("model_use_pos_embed", True)),
|
||||
backbone_type=str(config.get("backbone_type", "gru")),
|
||||
transformer_num_layers=int(config.get("transformer_num_layers", 4)),
|
||||
transformer_nhead=int(config.get("transformer_nhead", 8)),
|
||||
transformer_ff_dim=int(config.get("transformer_ff_dim", 2048)),
|
||||
transformer_dropout=float(config.get("transformer_dropout", 0.1)),
|
||||
cond_vocab_size=cond_vocab_size,
|
||||
cond_dim=int(config.get("cond_dim", 32)),
|
||||
use_tanh_eps=bool(config.get("use_tanh_eps", False)),
|
||||
|
||||
Reference in New Issue
Block a user