update新结构
This commit is contained in:
@@ -56,6 +56,9 @@ def main():
|
||||
model_ff_mult = int(cfg.get("model_ff_mult", 2))
|
||||
model_pos_dim = int(cfg.get("model_pos_dim", 64))
|
||||
model_use_pos = bool(cfg.get("model_use_pos_embed", True))
|
||||
model_use_feature_graph = bool(cfg.get("model_use_feature_graph", False))
|
||||
feature_graph_scale = float(cfg.get("feature_graph_scale", 0.1))
|
||||
feature_graph_dropout = float(cfg.get("feature_graph_dropout", 0.0))
|
||||
|
||||
split = load_split(str(SPLIT_PATH))
|
||||
time_col = split.get("time_column", "time")
|
||||
@@ -83,6 +86,9 @@ def main():
|
||||
ff_mult=model_ff_mult,
|
||||
pos_dim=model_pos_dim,
|
||||
use_pos_embed=model_use_pos,
|
||||
use_feature_graph=model_use_feature_graph,
|
||||
feature_graph_scale=feature_graph_scale,
|
||||
feature_graph_dropout=feature_graph_dropout,
|
||||
cond_vocab_size=cond_vocab_size,
|
||||
cond_dim=cond_dim,
|
||||
use_tanh_eps=use_tanh_eps,
|
||||
|
||||
Reference in New Issue
Block a user