This commit is contained in:
2026-01-26 22:17:35 +08:00
parent 2e273fb8a2
commit e88b1cab91
9 changed files with 447 additions and 4 deletions

View File

@@ -13,7 +13,7 @@ import torch
import torch.nn.functional as F
from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
@@ -140,6 +140,10 @@ def main():
raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob)
cont_target = str(cfg.get("cont_target", "eps"))
cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
model = HybridDiffusionModel(
cont_dim=len(cont_cols),
@@ -163,6 +167,20 @@ def main():
model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
model.eval()
temporal_model = None
if use_temporal_stage1:
temporal_model = TemporalGRUGenerator(
input_dim=len(cont_cols),
hidden_dim=temporal_hidden_dim,
num_layers=temporal_num_layers,
dropout=temporal_dropout,
).to(device)
temporal_path = Path(args.model_path).with_name("temporal.pt")
if not temporal_path.exists():
raise SystemExit(f"missing temporal model file: {temporal_path}")
temporal_model.load_state_dict(torch.load(temporal_path, map_location=device, weights_only=True))
temporal_model.eval()
betas = cosine_beta_schedule(args.timesteps).to(device)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
@@ -189,6 +207,10 @@ def main():
cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
cond = cond_id
trend = None
if temporal_model is not None:
trend = temporal_model.generate(args.batch_size, args.seq_len, device)
for t in reversed(range(args.timesteps)):
t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
@@ -225,6 +247,8 @@ def main():
)
x_disc[:, :, i][mask] = sampled[mask]
if trend is not None:
x_cont = x_cont + trend
# move to CPU for export
x_cont = x_cont.cpu()
x_disc = x_disc.cpu()