update
This commit is contained in:
@@ -68,6 +68,7 @@ python example/run_pipeline.py --device auto
|
|||||||
- Continuous head can be bounded with `tanh` via `use_tanh_eps` in config.
|
- Continuous head can be bounded with `tanh` via `use_tanh_eps` in config.
|
||||||
- Export now clamps continuous features to training min/max and preserves integer/decimal precision.
|
- Export now clamps continuous features to training min/max and preserves integer/decimal precision.
|
||||||
- `<UNK>` tokens are replaced by the most frequent token for each discrete column at export.
|
- `<UNK>` tokens are replaced by the most frequent token for each discrete column at export.
|
||||||
|
- Continuous prediction mode can be switched via `cont_pred` (`eps` or `x0`).
|
||||||
- The script only samples the first 5000 rows to stay fast.
|
- The script only samples the first 5000 rows to stay fast.
|
||||||
- `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
|
- `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
|
||||||
- `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU.
|
- `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU.
|
||||||
|
|||||||
@@ -25,6 +25,7 @@
|
|||||||
"cond_dim": 32,
|
"cond_dim": 32,
|
||||||
"use_tanh_eps": true,
|
"use_tanh_eps": true,
|
||||||
"eps_scale": 0.7,
|
"eps_scale": 0.7,
|
||||||
|
"cont_pred": "x0",
|
||||||
"sample_batch_size": 8,
|
"sample_batch_size": 8,
|
||||||
"sample_seq_len": 128
|
"sample_seq_len": 128
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -179,14 +179,20 @@ def main():
|
|||||||
cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
|
cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
|
||||||
cond = cond_id
|
cond = cond_id
|
||||||
|
|
||||||
|
cont_pred = str(cfg.get("cont_pred", "eps")).lower()
|
||||||
for t in reversed(range(args.timesteps)):
|
for t in reversed(range(args.timesteps)):
|
||||||
t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
|
t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
|
||||||
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
|
cont_pred_out, logits = model(x_cont, x_disc, t_batch, cond)
|
||||||
|
|
||||||
a_t = alphas[t]
|
a_t = alphas[t]
|
||||||
a_bar_t = alphas_cumprod[t]
|
a_bar_t = alphas_cumprod[t]
|
||||||
coef1 = 1.0 / torch.sqrt(a_t)
|
coef1 = 1.0 / torch.sqrt(a_t)
|
||||||
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
|
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
|
||||||
|
if cont_pred == "x0":
|
||||||
|
# eps = (x_t - sqrt(a_bar) * x0) / sqrt(1 - a_bar)
|
||||||
|
eps_pred = (x_cont - torch.sqrt(a_bar_t) * cont_pred_out) / torch.sqrt(1 - a_bar_t + 1e-8)
|
||||||
|
else:
|
||||||
|
eps_pred = cont_pred_out
|
||||||
mean_x = coef1 * (x_cont - coef2 * eps_pred)
|
mean_x = coef1 * (x_cont - coef2 * eps_pred)
|
||||||
if t > 0:
|
if t > 0:
|
||||||
noise = torch.randn_like(x_cont)
|
noise = torch.randn_like(x_cont)
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ def main():
|
|||||||
batch_size = int(cfg.get("sample_batch_size", cfg.get("batch_size", BATCH_SIZE)))
|
batch_size = int(cfg.get("sample_batch_size", cfg.get("batch_size", BATCH_SIZE)))
|
||||||
clip_k = float(cfg.get("clip_k", CLIP_K))
|
clip_k = float(cfg.get("clip_k", CLIP_K))
|
||||||
use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
|
use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
|
||||||
|
cont_pred = str(cfg.get("cont_pred", "eps")).lower()
|
||||||
cond_dim = int(cfg.get("cond_dim", 32))
|
cond_dim = int(cfg.get("cond_dim", 32))
|
||||||
use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
|
use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
|
||||||
eps_scale = float(cfg.get("eps_scale", 1.0))
|
eps_scale = float(cfg.get("eps_scale", 1.0))
|
||||||
@@ -96,13 +97,17 @@ def main():
|
|||||||
|
|
||||||
for t in reversed(range(timesteps)):
|
for t in reversed(range(timesteps)):
|
||||||
t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
|
t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
|
||||||
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
|
cont_pred_out, logits = model(x_cont, x_disc, t_batch, cond)
|
||||||
|
|
||||||
# Continuous reverse step (DDPM): x_{t-1} mean
|
# Continuous reverse step (DDPM): x_{t-1} mean
|
||||||
a_t = alphas[t]
|
a_t = alphas[t]
|
||||||
a_bar_t = alphas_cumprod[t]
|
a_bar_t = alphas_cumprod[t]
|
||||||
coef1 = 1.0 / torch.sqrt(a_t)
|
coef1 = 1.0 / torch.sqrt(a_t)
|
||||||
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
|
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
|
||||||
|
if cont_pred == "x0":
|
||||||
|
eps_pred = (x_cont - torch.sqrt(a_bar_t) * cont_pred_out) / torch.sqrt(1 - a_bar_t + 1e-8)
|
||||||
|
else:
|
||||||
|
eps_pred = cont_pred_out
|
||||||
mean = coef1 * (x_cont - coef2 * eps_pred)
|
mean = coef1 * (x_cont - coef2 * eps_pred)
|
||||||
if t > 0:
|
if t > 0:
|
||||||
noise = torch.randn_like(x_cont)
|
noise = torch.randn_like(x_cont)
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ DEFAULTS = {
|
|||||||
"cond_dim": 32,
|
"cond_dim": 32,
|
||||||
"use_tanh_eps": True,
|
"use_tanh_eps": True,
|
||||||
"eps_scale": 1.0,
|
"eps_scale": 1.0,
|
||||||
|
"cont_pred": "eps",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -161,6 +162,7 @@ def main():
|
|||||||
|
|
||||||
device = resolve_device(str(config["device"]))
|
device = resolve_device(str(config["device"]))
|
||||||
print("device", device)
|
print("device", device)
|
||||||
|
cont_pred = str(config.get("cont_pred", "eps")).lower()
|
||||||
model = HybridDiffusionModel(
|
model = HybridDiffusionModel(
|
||||||
cont_dim=len(cont_cols),
|
cont_dim=len(cont_cols),
|
||||||
disc_vocab_sizes=vocab_sizes,
|
disc_vocab_sizes=vocab_sizes,
|
||||||
@@ -217,9 +219,12 @@ def main():
|
|||||||
mask_tokens = torch.tensor(vocab_sizes, device=device)
|
mask_tokens = torch.tensor(vocab_sizes, device=device)
|
||||||
x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"]))
|
x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"]))
|
||||||
|
|
||||||
eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)
|
cont_pred_out, logits = model(x_cont_t, x_disc_t, t, cond)
|
||||||
|
|
||||||
loss_cont = F.mse_loss(eps_pred, noise)
|
if cont_pred == "x0":
|
||||||
|
loss_cont = F.mse_loss(cont_pred_out, x_cont)
|
||||||
|
else:
|
||||||
|
loss_cont = F.mse_loss(cont_pred_out, noise)
|
||||||
loss_disc = 0.0
|
loss_disc = 0.0
|
||||||
loss_disc_count = 0
|
loss_disc_count = 0
|
||||||
for i, logit in enumerate(logits):
|
for i, logit in enumerate(logits):
|
||||||
|
|||||||
Reference in New Issue
Block a user