transformer

This commit is contained in:
2026-01-27 00:41:42 +08:00
parent 65391910a2
commit 334db7082b
12 changed files with 175 additions and 11 deletions

View File

@@ -60,6 +60,11 @@ def main():
model_ff_mult = int(cfg.get("model_ff_mult", 2))
model_pos_dim = int(cfg.get("model_pos_dim", 64))
model_use_pos = bool(cfg.get("model_use_pos_embed", True))
backbone_type = str(cfg.get("backbone_type", "gru"))
transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
transformer_nhead = int(cfg.get("transformer_nhead", 4))
transformer_ff_dim = int(cfg.get("transformer_ff_dim", 512))
transformer_dropout = float(cfg.get("transformer_dropout", 0.1))
split = load_split(str(SPLIT_PATH))
time_col = split.get("time_column", "time")
@@ -87,6 +92,11 @@ def main():
ff_mult=model_ff_mult,
pos_dim=model_pos_dim,
use_pos_embed=model_use_pos,
backbone_type=backbone_type,
transformer_num_layers=transformer_num_layers,
transformer_nhead=transformer_nhead,
transformer_ff_dim=transformer_ff_dim,
transformer_dropout=transformer_dropout,
cond_vocab_size=cond_vocab_size,
cond_dim=cond_dim,
use_tanh_eps=use_tanh_eps,