This commit is contained in:
2026-01-26 22:17:35 +08:00
parent 2e273fb8a2
commit e88b1cab91
9 changed files with 447 additions and 4 deletions

View File

@@ -10,7 +10,7 @@ import torch
import torch.nn.functional as F
from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir
BASE_DIR = Path(__file__).resolve().parent
@@ -47,6 +47,10 @@ def main():
cond_dim = int(cfg.get("cond_dim", 32))
use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
eps_scale = float(cfg.get("eps_scale", 1.0))
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
cont_target = str(cfg.get("cont_target", "eps"))
cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
model_time_dim = int(cfg.get("model_time_dim", 64))
@@ -92,6 +96,20 @@ def main():
model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
model.eval()
temporal_model = None
if use_temporal_stage1:
temporal_model = TemporalGRUGenerator(
input_dim=len(cont_cols),
hidden_dim=temporal_hidden_dim,
num_layers=temporal_num_layers,
dropout=temporal_dropout,
).to(DEVICE)
temporal_path = BASE_DIR / "results" / "temporal.pt"
if not temporal_path.exists():
raise SystemExit(f"missing temporal model file: {temporal_path}")
temporal_model.load_state_dict(torch.load(str(temporal_path), map_location=DEVICE, weights_only=True))
temporal_model.eval()
betas = cosine_beta_schedule(timesteps).to(DEVICE)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
@@ -110,6 +128,10 @@ def main():
raise SystemExit("use_condition enabled but no files matched data_glob")
cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long)
trend = None
if temporal_model is not None:
trend = temporal_model.generate(batch_size, seq_len, DEVICE)
for t in reversed(range(timesteps)):
t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
@@ -146,6 +168,8 @@ def main():
sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(BATCH_SIZE, SEQ_LEN)
x_disc[:, :, i][mask] = sampled[mask]
if trend is not None:
x_cont = x_cont + trend
print("sampled_cont_shape", tuple(x_cont.shape))
print("sampled_disc_shape", tuple(x_disc.shape))