update
This commit is contained in:
@@ -12,7 +12,7 @@ from typing import Dict, List
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from data_utils import load_split, inverse_quantile_transform, quantile_calibrate_to_real
|
||||
from data_utils import load_split, normalize_cont, inverse_quantile_transform, quantile_calibrate_to_real
|
||||
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, TemporalTransformerGenerator, cosine_beta_schedule
|
||||
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
|
||||
|
||||
@@ -157,10 +157,15 @@ def main():
|
||||
cont_post_calibrate = bool(cfg.get("cont_post_calibrate", False))
|
||||
type1_cols = cfg.get("type1_features", []) or []
|
||||
type5_cols = cfg.get("type5_features", []) or []
|
||||
type4_cols = cfg.get("type4_features", []) or []
|
||||
type1_cols = [c for c in type1_cols if c in cont_cols]
|
||||
type5_cols = [c for c in type5_cols if c in cont_cols]
|
||||
type4_cols = [c for c in type4_cols if c in cont_cols]
|
||||
model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
|
||||
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
|
||||
temporal_use_type1_cond = bool(cfg.get("temporal_use_type1_cond", False))
|
||||
temporal_focus_type4 = bool(cfg.get("temporal_focus_type4", False))
|
||||
temporal_exclude_type4 = bool(cfg.get("temporal_exclude_type4", False))
|
||||
temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
|
||||
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
|
||||
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
|
||||
@@ -207,6 +212,22 @@ def main():
|
||||
|
||||
temporal_model = None
|
||||
if use_temporal_stage1:
|
||||
temporal_path = Path(args.model_path).with_name("temporal.pt")
|
||||
if not temporal_path.exists():
|
||||
raise SystemExit(f"missing temporal model file: {temporal_path}")
|
||||
temporal_state = load_torch_state(str(temporal_path), device)
|
||||
temporal_cond_dim = len(type1_cols) if (temporal_use_type1_cond and type1_cols) else 0
|
||||
if isinstance(temporal_state, dict):
|
||||
if "in_proj.weight" in temporal_state:
|
||||
try:
|
||||
temporal_cond_dim = max(0, int(temporal_state["in_proj.weight"].shape[1]) - len(model_cont_cols))
|
||||
except Exception:
|
||||
pass
|
||||
elif "gru.weight_ih_l0" in temporal_state:
|
||||
try:
|
||||
temporal_cond_dim = max(0, int(temporal_state["gru.weight_ih_l0"].shape[1]) - len(model_cont_cols))
|
||||
except Exception:
|
||||
pass
|
||||
if temporal_backbone == "transformer":
|
||||
temporal_model = TemporalTransformerGenerator(
|
||||
input_dim=len(model_cont_cols),
|
||||
@@ -217,6 +238,7 @@ def main():
|
||||
dropout=temporal_transformer_dropout,
|
||||
pos_dim=temporal_pos_dim,
|
||||
use_pos_embed=temporal_use_pos_embed,
|
||||
cond_dim=temporal_cond_dim,
|
||||
).to(device)
|
||||
else:
|
||||
temporal_model = TemporalGRUGenerator(
|
||||
@@ -224,11 +246,9 @@ def main():
|
||||
hidden_dim=temporal_hidden_dim,
|
||||
num_layers=temporal_num_layers,
|
||||
dropout=temporal_dropout,
|
||||
cond_dim=temporal_cond_dim,
|
||||
).to(device)
|
||||
temporal_path = Path(args.model_path).with_name("temporal.pt")
|
||||
if not temporal_path.exists():
|
||||
raise SystemExit(f"missing temporal model file: {temporal_path}")
|
||||
temporal_model.load_state_dict(load_torch_state(str(temporal_path), device))
|
||||
temporal_model.load_state_dict(temporal_state)
|
||||
temporal_model.eval()
|
||||
|
||||
betas = cosine_beta_schedule(args.timesteps).to(device)
|
||||
@@ -279,13 +299,32 @@ def main():
|
||||
for t, row in enumerate(seq):
|
||||
for i, c in enumerate(type1_cols):
|
||||
cond_cont[:, t, i] = float(row[c])
|
||||
mean_vec = torch.tensor([mean[c] for c in type1_cols], dtype=cond_cont.dtype, device=device)
|
||||
std_vec = torch.tensor([std[c] for c in type1_cols], dtype=cond_cont.dtype, device=device)
|
||||
cond_cont = (cond_cont - mean_vec) / std_vec
|
||||
cond_cont = normalize_cont(
|
||||
cond_cont,
|
||||
type1_cols,
|
||||
mean,
|
||||
std,
|
||||
transforms=transforms,
|
||||
quantile_probs=quantile_probs,
|
||||
quantile_values=quantile_values,
|
||||
use_quantile=use_quantile,
|
||||
)
|
||||
|
||||
trend = None
|
||||
if temporal_model is not None:
|
||||
trend = temporal_model.generate(args.batch_size, args.seq_len, device)
|
||||
trend = temporal_model.generate(args.batch_size, args.seq_len, device, cond_cont=cond_cont)
|
||||
if temporal_focus_type4 and type4_cols:
|
||||
type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols]
|
||||
if type4_model_idx:
|
||||
trend_mask = torch.zeros(1, 1, len(model_cont_cols), device=device, dtype=trend.dtype)
|
||||
trend_mask[:, :, type4_model_idx] = 1.0
|
||||
trend = trend * trend_mask
|
||||
elif temporal_exclude_type4 and type4_cols:
|
||||
type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols]
|
||||
if type4_model_idx:
|
||||
trend_mask = torch.ones(1, 1, len(model_cont_cols), device=device, dtype=trend.dtype)
|
||||
trend_mask[:, :, type4_model_idx] = 0.0
|
||||
trend = trend * trend_mask
|
||||
|
||||
for t in reversed(range(args.timesteps)):
|
||||
t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
|
||||
|
||||
Reference in New Issue
Block a user