连续型特征在时许相关性上的不足
This commit is contained in:
@@ -49,8 +49,17 @@ DEFAULTS = {
|
||||
"use_condition": True,
|
||||
"condition_type": "file_id",
|
||||
"cond_dim": 32,
|
||||
"use_tanh_eps": True,
|
||||
"use_tanh_eps": False,
|
||||
"eps_scale": 1.0,
|
||||
"model_time_dim": 128,
|
||||
"model_hidden_dim": 512,
|
||||
"model_num_layers": 2,
|
||||
"model_dropout": 0.1,
|
||||
"model_ff_mult": 2,
|
||||
"model_pos_dim": 64,
|
||||
"model_use_pos_embed": True,
|
||||
"disc_mask_scale": 0.9,
|
||||
"shuffle_buffer": 256,
|
||||
}
|
||||
|
||||
|
||||
@@ -144,6 +153,7 @@ def main():
|
||||
stats = load_json(config["stats_path"])
|
||||
mean = stats["mean"]
|
||||
std = stats["std"]
|
||||
transforms = stats.get("transform", {})
|
||||
|
||||
vocab = load_json(config["vocab_path"])["vocab"]
|
||||
vocab_sizes = [len(vocab[c]) for c in disc_cols]
|
||||
@@ -164,6 +174,13 @@ def main():
|
||||
model = HybridDiffusionModel(
|
||||
cont_dim=len(cont_cols),
|
||||
disc_vocab_sizes=vocab_sizes,
|
||||
time_dim=int(config.get("model_time_dim", 64)),
|
||||
hidden_dim=int(config.get("model_hidden_dim", 256)),
|
||||
num_layers=int(config.get("model_num_layers", 1)),
|
||||
dropout=float(config.get("model_dropout", 0.0)),
|
||||
ff_mult=int(config.get("model_ff_mult", 2)),
|
||||
pos_dim=int(config.get("model_pos_dim", 64)),
|
||||
use_pos_embed=bool(config.get("model_use_pos_embed", True)),
|
||||
cond_vocab_size=cond_vocab_size,
|
||||
cond_dim=int(config.get("cond_dim", 32)),
|
||||
use_tanh_eps=bool(config.get("use_tanh_eps", False)),
|
||||
@@ -198,6 +215,8 @@ def main():
|
||||
seq_len=int(config["seq_len"]),
|
||||
max_batches=int(config["max_batches"]),
|
||||
return_file_id=use_condition,
|
||||
transforms=transforms,
|
||||
shuffle_buffer=int(config.get("shuffle_buffer", 0)),
|
||||
)
|
||||
):
|
||||
if use_condition:
|
||||
@@ -215,7 +234,13 @@ def main():
|
||||
x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
|
||||
|
||||
mask_tokens = torch.tensor(vocab_sizes, device=device)
|
||||
x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"]))
|
||||
x_disc_t, mask = q_sample_discrete(
|
||||
x_disc,
|
||||
t,
|
||||
mask_tokens,
|
||||
int(config["timesteps"]),
|
||||
mask_scale=float(config.get("disc_mask_scale", 1.0)),
|
||||
)
|
||||
|
||||
eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user