update: add an optional x0 prediction target (with clamping) for the continuous head

Adds two settings, `cont_target` (`"eps"` | `"x0"`) and `cont_clamp_x0`, threaded through the training loss and both sampling scripts. The sample config switches to x0 prediction with a clamp of 5.0; the code defaults (`"eps"`, `0.0`) preserve the previous behavior.
Sample config:

```diff
@@ -35,6 +35,8 @@
   "disc_mask_scale": 0.9,
   "cont_loss_weighting": "inv_std",
   "cont_loss_eps": 1e-6,
+  "cont_target": "x0",
+  "cont_clamp_x0": 5.0,
   "shuffle_buffer": 256,
   "sample_batch_size": 8,
   "sample_seq_len": 128
```
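With `cont_target: "x0"` the continuous head regresses the clean values rather than the injected noise, and `cont_clamp_x0: 5.0` bounds that prediction. The two parameterizations are interchangeable through the standard DDPM forward identity; a minimal sketch (shapes are hypothetical, not code from this repo):

```python
import torch

# DDPM forward identity: x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * eps,
# so an x0 prediction determines an eps prediction and vice versa.
a_bar_t = torch.tensor(0.5)    # hypothetical cumulative product of alphas
x0 = torch.randn(8, 128, 4)    # hypothetical (batch, seq, cont-features)
eps = torch.randn_like(x0)
x_t = torch.sqrt(a_bar_t) * x0 + torch.sqrt(1.0 - a_bar_t) * eps

# Inverting the identity for eps recovers the original noise exactly.
eps_back = (x_t - torch.sqrt(a_bar_t) * x0) / torch.sqrt(1.0 - a_bar_t)
assert torch.allclose(eps_back, eps, atol=1e-5)
```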
First sampling script:

```diff
@@ -112,6 +112,8 @@ def main():
     int_like = stats.get("int_like", {})
     max_decimals = stats.get("max_decimals", {})
     transforms = stats.get("transform", {})
+    cont_target = str(cfg.get("cont_target", "eps"))
+    cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
 
     vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
     vocab = vocab_json["vocab"]
@@ -191,6 +193,12 @@ def main():
         t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
         eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
 
         a_t = alphas[t]
         a_bar_t = alphas_cumprod[t]
+        # If the model predicts x0, recover eps before forming the DDPM mean.
+        if cont_target == "x0":
+            x0_pred = eps_pred
+            if cont_clamp_x0 > 0:
+                x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
+            eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
         coef1 = 1.0 / torch.sqrt(a_t)
```
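At sampling time the reverse-step math itself is unchanged: when the checkpoint was trained with `cont_target: "x0"`, the raw network output is read as x0, optionally clamped, and converted back to eps, which then feeds the existing `coef1 = 1.0 / torch.sqrt(a_t)` mean computation. A minimal sketch of that conversion (the function name and signature are illustrative, not from the repo):

```python
import torch

def eps_from_x0(x_t: torch.Tensor, x0_pred: torch.Tensor,
                a_bar_t: torch.Tensor, clamp: float = 0.0) -> torch.Tensor:
    """Convert an x0 prediction into the eps the DDPM update expects."""
    if clamp > 0:
        # With standardized features (an assumption here), this caps the
        # prediction at `clamp` standard deviations.
        x0_pred = torch.clamp(x0_pred, -clamp, clamp)
    # Invert x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * eps.
    return (x_t - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
```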
Second sampling script:

```diff
@@ -47,6 +47,8 @@ def main():
     cond_dim = int(cfg.get("cond_dim", 32))
     use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
     eps_scale = float(cfg.get("eps_scale", 1.0))
+    cont_target = str(cfg.get("cont_target", "eps"))
+    cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
     model_time_dim = int(cfg.get("model_time_dim", 64))
     model_hidden_dim = int(cfg.get("model_hidden_dim", 256))
     model_num_layers = int(cfg.get("model_num_layers", 1))
@@ -112,6 +114,12 @@ def main():
         t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
         eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
 
         # Continuous reverse step (DDPM): x_{t-1} mean
         a_t = alphas[t]
         a_bar_t = alphas_cumprod[t]
+        # If the model predicts x0, recover eps before forming the DDPM mean.
+        if cont_target == "x0":
+            x0_pred = eps_pred
+            if cont_clamp_x0 > 0:
+                x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
+            eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
```
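The same two config reads and the same conversion are mirrored here so that both sampling entry points honor the checkpoint's parameterization. Note the chosen fallbacks: `cont_target` defaults to `"eps"` and `cont_clamp_x0` to `0.0` (disabled), so configs and checkpoints that predate this change keep their old behavior unmodified.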
Training script:

```diff
@@ -62,6 +62,8 @@ DEFAULTS = {
     "shuffle_buffer": 256,
     "cont_loss_weighting": "none", # none | inv_std
     "cont_loss_eps": 1e-6,
+    "cont_target": "eps", # eps | x0
+    "cont_clamp_x0": 0.0,
 }
 
 
@@ -249,15 +251,24 @@ def main():
 
         eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)
 
+        cont_target = str(config.get("cont_target", "eps"))
+        if cont_target == "x0":
+            x0_target = x_cont
+            if float(config.get("cont_clamp_x0", 0.0)) > 0:
+                x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"]))
+            loss_base = (eps_pred - x0_target) ** 2
+        else:
+            loss_base = (eps_pred - noise) ** 2
+
         if config.get("cont_loss_weighting") == "inv_std":
             weights = torch.tensor(
                 [1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in cont_cols],
                 device=device,
                 dtype=eps_pred.dtype,
             ).view(1, 1, -1)
-            loss_cont = ((eps_pred - noise) ** 2 * weights).mean()
+            loss_cont = (loss_base * weights).mean()
         else:
-            loss_cont = F.mse_loss(eps_pred, noise)
+            loss_cont = loss_base.mean()
         loss_disc = 0.0
         loss_disc_count = 0
         for i, logit in enumerate(logits):
```
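On the training side only the regression target changes; the optional inverse-variance column weighting now applies to either residual. A compact sketch of the resulting continuous loss, assuming `raw_std` is already a per-column tensor (the repo builds it from per-column stats; names here are illustrative):

```python
import torch

def continuous_loss(pred: torch.Tensor, noise: torch.Tensor, x0: torch.Tensor,
                    raw_std: torch.Tensor, cont_target: str = "eps",
                    clamp_x0: float = 0.0, weighting: str = "none",
                    eps: float = 1e-6) -> torch.Tensor:
    """Sketch of the continuous-head loss after this change."""
    if cont_target == "x0":
        # Clamp the target itself when clamping is enabled.
        target = torch.clamp(x0, -clamp_x0, clamp_x0) if clamp_x0 > 0 else x0
    else:
        target = noise
    per_elem = (pred - target) ** 2          # (batch, seq, cont-cols)
    if weighting == "inv_std":
        # Down-weight high-variance columns, matching the inv_std branch above.
        w = 1.0 / (raw_std ** 2 + eps)
        per_elem = per_elem * w.view(1, 1, -1)
    return per_elem.mean()
```

One detail worth noting: with `cont_target: "x0"` the clamp is applied to the training target here as well as to the sampler's prediction, so the network is never asked to regress values outside the clamp range.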