From 0f741564601ad203be14ce99e11d6bf9176ac9bc Mon Sep 17 00:00:00 2001 From: MingzheYang Date: Fri, 23 Jan 2026 12:00:29 +0800 Subject: [PATCH] Add configurable continuous prediction target `cont_pred` (eps or x0) --- example/README.md | 1 + example/config.json | 1 + example/export_samples.py | 8 +++++++- example/sample.py | 7 ++++++- example/train.py | 9 +++++++-- 5 files changed, 22 insertions(+), 4 deletions(-) diff --git a/example/README.md b/example/README.md index 32ffe81..88c7beb 100644 --- a/example/README.md +++ b/example/README.md @@ -68,6 +68,7 @@ python example/run_pipeline.py --device auto - Continuous head can be bounded with `tanh` via `use_tanh_eps` in config. - Export now clamps continuous features to training min/max and preserves integer/decimal precision. - `` tokens are replaced by the most frequent token for each discrete column at export. +- Continuous prediction mode can be switched via `cont_pred` (`eps` or `x0`). - The script only samples the first 5000 rows to stay fast. - `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it. - `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU. 
diff --git a/example/config.json b/example/config.json index 8fdfa9b..0146ea8 100644 --- a/example/config.json +++ b/example/config.json @@ -25,6 +25,7 @@ "cond_dim": 32, "use_tanh_eps": true, "eps_scale": 0.7, + "cont_pred": "x0", "sample_batch_size": 8, "sample_seq_len": 128 } diff --git a/example/export_samples.py b/example/export_samples.py index 016ad71..88b49aa 100644 --- a/example/export_samples.py +++ b/example/export_samples.py @@ -179,14 +179,20 @@ def main(): cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long) cond = cond_id + cont_pred = str(cfg.get("cont_pred", "eps")).lower() for t in reversed(range(args.timesteps)): t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long) - eps_pred, logits = model(x_cont, x_disc, t_batch, cond) + cont_pred_out, logits = model(x_cont, x_disc, t_batch, cond) a_t = alphas[t] a_bar_t = alphas_cumprod[t] coef1 = 1.0 / torch.sqrt(a_t) coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t) + if cont_pred == "x0": + # eps = (x_t - sqrt(a_bar) * x0) / sqrt(1 - a_bar) + eps_pred = (x_cont - torch.sqrt(a_bar_t) * cont_pred_out) / torch.sqrt(1 - a_bar_t + 1e-8) + else: + eps_pred = cont_pred_out mean_x = coef1 * (x_cont - coef2 * eps_pred) if t > 0: noise = torch.randn_like(x_cont) diff --git a/example/sample.py b/example/sample.py index 14a5863..d696f4c 100755 --- a/example/sample.py +++ b/example/sample.py @@ -44,6 +44,7 @@ def main(): batch_size = int(cfg.get("sample_batch_size", cfg.get("batch_size", BATCH_SIZE))) clip_k = float(cfg.get("clip_k", CLIP_K)) use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id" + cont_pred = str(cfg.get("cont_pred", "eps")).lower() cond_dim = int(cfg.get("cond_dim", 32)) use_tanh_eps = bool(cfg.get("use_tanh_eps", False)) eps_scale = float(cfg.get("eps_scale", 1.0)) @@ -96,13 +97,17 @@ def main(): for t in reversed(range(timesteps)): t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long) 
- eps_pred, logits = model(x_cont, x_disc, t_batch, cond) + cont_pred_out, logits = model(x_cont, x_disc, t_batch, cond) # Continuous reverse step (DDPM): x_{t-1} mean a_t = alphas[t] a_bar_t = alphas_cumprod[t] coef1 = 1.0 / torch.sqrt(a_t) coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t) + if cont_pred == "x0": + eps_pred = (x_cont - torch.sqrt(a_bar_t) * cont_pred_out) / torch.sqrt(1 - a_bar_t + 1e-8) + else: + eps_pred = cont_pred_out mean = coef1 * (x_cont - coef2 * eps_pred) if t > 0: noise = torch.randn_like(x_cont) diff --git a/example/train.py b/example/train.py index 13699c7..b21040d 100755 --- a/example/train.py +++ b/example/train.py @@ -51,6 +51,7 @@ DEFAULTS = { "cond_dim": 32, "use_tanh_eps": True, "eps_scale": 1.0, + "cont_pred": "eps", } @@ -161,6 +162,7 @@ def main(): device = resolve_device(str(config["device"])) print("device", device) + cont_pred = str(config.get("cont_pred", "eps")).lower() model = HybridDiffusionModel( cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes, @@ -217,9 +219,12 @@ def main(): mask_tokens = torch.tensor(vocab_sizes, device=device) x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"])) - eps_pred, logits = model(x_cont_t, x_disc_t, t, cond) + cont_pred_out, logits = model(x_cont_t, x_disc_t, t, cond) - loss_cont = F.mse_loss(eps_pred, noise) + if cont_pred == "x0": + loss_cont = F.mse_loss(cont_pred_out, x_cont) + else: + loss_cont = F.mse_loss(cont_pred_out, noise) loss_disc = 0.0 loss_disc_count = 0 for i, logit in enumerate(logits):