update 新结构

This commit is contained in:
2026-01-26 19:00:16 +08:00
parent f8edee9510
commit dd4c1e171f
9 changed files with 545 additions and 5 deletions

View File

@@ -11,6 +11,7 @@ import torch.nn.functional as F
from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
from hybrid_diffusion import TemporalGRUGenerator
from platform_utils import resolve_device, safe_path, ensure_dir
BASE_DIR = Path(__file__).resolve().parent
@@ -59,6 +60,10 @@ def main():
model_use_feature_graph = bool(cfg.get("model_use_feature_graph", False))
feature_graph_scale = float(cfg.get("feature_graph_scale", 0.1))
feature_graph_dropout = float(cfg.get("feature_graph_dropout", 0.0))
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
split = load_split(str(SPLIT_PATH))
time_col = split.get("time_column", "time")
@@ -98,6 +103,20 @@ def main():
model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
model.eval()
temporal_model = None
if use_temporal_stage1:
temporal_model = TemporalGRUGenerator(
input_dim=len(cont_cols),
hidden_dim=temporal_hidden_dim,
num_layers=temporal_num_layers,
dropout=temporal_dropout,
).to(DEVICE)
temporal_path = BASE_DIR / "results" / "temporal.pt"
if not temporal_path.exists():
raise SystemExit(f"missing temporal model file: {temporal_path}")
temporal_model.load_state_dict(torch.load(str(temporal_path), map_location=DEVICE, weights_only=True))
temporal_model.eval()
betas = cosine_beta_schedule(timesteps).to(DEVICE)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
@@ -116,6 +135,10 @@ def main():
raise SystemExit("use_condition enabled but no files matched data_glob")
cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long)
trend = None
if temporal_model is not None:
trend = temporal_model.generate(batch_size, seq_len, DEVICE)
for t in reversed(range(timesteps)):
t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
@@ -155,6 +178,8 @@ def main():
sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(BATCH_SIZE, SEQ_LEN)
x_disc[:, :, i][mask] = sampled[mask]
if trend is not None:
x_cont = x_cont + trend
print("sampled_cont_shape", tuple(x_cont.shape))
print("sampled_disc_shape", tuple(x_disc.shape))