transformer

This commit is contained in:
2026-01-27 00:41:42 +08:00
parent 65391910a2
commit 334db7082b
12 changed files with 175 additions and 11 deletions

View File

@@ -200,6 +200,11 @@ def main():
ff_mult=int(config.get("model_ff_mult", 2)),
pos_dim=int(config.get("model_pos_dim", 64)),
use_pos_embed=bool(config.get("model_use_pos_embed", True)),
backbone_type=str(config.get("backbone_type", "gru")),
transformer_num_layers=int(config.get("transformer_num_layers", 4)),
transformer_nhead=int(config.get("transformer_nhead", 8)),
transformer_ff_dim=int(config.get("transformer_ff_dim", 2048)),
transformer_dropout=float(config.get("transformer_dropout", 0.1)),
cond_vocab_size=cond_vocab_size,
cond_dim=int(config.get("cond_dim", 32)),
use_tanh_eps=bool(config.get("use_tanh_eps", False)),