This commit is contained in:
MZ YANG
2026-02-10 18:24:59 +08:00
parent 7eee14ba2a
commit ccb33bf876
20 changed files with 174700 additions and 986 deletions

View File

@@ -12,7 +12,7 @@ from typing import Dict, List
import torch
import torch.nn.functional as F
from data_utils import load_split, inverse_quantile_transform, quantile_calibrate_to_real
from data_utils import load_split, normalize_cont, inverse_quantile_transform, quantile_calibrate_to_real
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, TemporalTransformerGenerator, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
@@ -157,10 +157,15 @@ def main():
cont_post_calibrate = bool(cfg.get("cont_post_calibrate", False))
type1_cols = cfg.get("type1_features", []) or []
type5_cols = cfg.get("type5_features", []) or []
type4_cols = cfg.get("type4_features", []) or []
type1_cols = [c for c in type1_cols if c in cont_cols]
type5_cols = [c for c in type5_cols if c in cont_cols]
type4_cols = [c for c in type4_cols if c in cont_cols]
model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
temporal_use_type1_cond = bool(cfg.get("temporal_use_type1_cond", False))
temporal_focus_type4 = bool(cfg.get("temporal_focus_type4", False))
temporal_exclude_type4 = bool(cfg.get("temporal_exclude_type4", False))
temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
@@ -207,6 +212,22 @@ def main():
temporal_model = None
if use_temporal_stage1:
temporal_path = Path(args.model_path).with_name("temporal.pt")
if not temporal_path.exists():
raise SystemExit(f"missing temporal model file: {temporal_path}")
temporal_state = load_torch_state(str(temporal_path), device)
temporal_cond_dim = len(type1_cols) if (temporal_use_type1_cond and type1_cols) else 0
if isinstance(temporal_state, dict):
if "in_proj.weight" in temporal_state:
try:
temporal_cond_dim = max(0, int(temporal_state["in_proj.weight"].shape[1]) - len(model_cont_cols))
except Exception:
pass
elif "gru.weight_ih_l0" in temporal_state:
try:
temporal_cond_dim = max(0, int(temporal_state["gru.weight_ih_l0"].shape[1]) - len(model_cont_cols))
except Exception:
pass
if temporal_backbone == "transformer":
temporal_model = TemporalTransformerGenerator(
input_dim=len(model_cont_cols),
@@ -217,6 +238,7 @@ def main():
dropout=temporal_transformer_dropout,
pos_dim=temporal_pos_dim,
use_pos_embed=temporal_use_pos_embed,
cond_dim=temporal_cond_dim,
).to(device)
else:
temporal_model = TemporalGRUGenerator(
@@ -224,11 +246,9 @@ def main():
hidden_dim=temporal_hidden_dim,
num_layers=temporal_num_layers,
dropout=temporal_dropout,
cond_dim=temporal_cond_dim,
).to(device)
temporal_path = Path(args.model_path).with_name("temporal.pt")
if not temporal_path.exists():
raise SystemExit(f"missing temporal model file: {temporal_path}")
temporal_model.load_state_dict(load_torch_state(str(temporal_path), device))
temporal_model.load_state_dict(temporal_state)
temporal_model.eval()
betas = cosine_beta_schedule(args.timesteps).to(device)
@@ -279,13 +299,32 @@ def main():
for t, row in enumerate(seq):
for i, c in enumerate(type1_cols):
cond_cont[:, t, i] = float(row[c])
mean_vec = torch.tensor([mean[c] for c in type1_cols], dtype=cond_cont.dtype, device=device)
std_vec = torch.tensor([std[c] for c in type1_cols], dtype=cond_cont.dtype, device=device)
cond_cont = (cond_cont - mean_vec) / std_vec
cond_cont = normalize_cont(
cond_cont,
type1_cols,
mean,
std,
transforms=transforms,
quantile_probs=quantile_probs,
quantile_values=quantile_values,
use_quantile=use_quantile,
)
trend = None
if temporal_model is not None:
trend = temporal_model.generate(args.batch_size, args.seq_len, device)
trend = temporal_model.generate(args.batch_size, args.seq_len, device, cond_cont=cond_cont)
if temporal_focus_type4 and type4_cols:
type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols]
if type4_model_idx:
trend_mask = torch.zeros(1, 1, len(model_cont_cols), device=device, dtype=trend.dtype)
trend_mask[:, :, type4_model_idx] = 1.0
trend = trend * trend_mask
elif temporal_exclude_type4 and type4_cols:
type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols]
if type4_model_idx:
trend_mask = torch.ones(1, 1, len(model_cont_cols), device=device, dtype=trend.dtype)
trend_mask[:, :, type4_model_idx] = 0.0
trend = trend * trend_mask
for t in reversed(range(args.timesteps)):
t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)