update新结构
This commit is contained in:
@@ -73,3 +73,4 @@ python example/run_pipeline.py --device auto
|
||||
- `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
|
||||
- `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU.
|
||||
- Optional feature-graph mixer (`model_use_feature_graph`) adds a learnable relation prior across feature channels.
|
||||
- Optional two-stage temporal backbone (`use_temporal_stage1`) learns a GRU-based sequence trend; diffusion models the residual.
|
||||
|
||||
@@ -35,6 +35,11 @@
|
||||
"model_use_feature_graph": true,
|
||||
"feature_graph_scale": 0.1,
|
||||
"feature_graph_dropout": 0.0,
|
||||
"use_temporal_stage1": true,
|
||||
"temporal_hidden_dim": 256,
|
||||
"temporal_num_layers": 1,
|
||||
"temporal_dropout": 0.0,
|
||||
"temporal_loss_weight": 1.0,
|
||||
"disc_mask_scale": 0.9,
|
||||
"cont_loss_weighting": "inv_std",
|
||||
"cont_loss_eps": 1e-6,
|
||||
|
||||
@@ -13,7 +13,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from data_utils import load_split
|
||||
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
|
||||
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
|
||||
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
|
||||
|
||||
|
||||
@@ -140,6 +140,10 @@ def main():
|
||||
raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob)
|
||||
cont_target = str(cfg.get("cont_target", "eps"))
|
||||
cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
|
||||
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
|
||||
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
|
||||
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
|
||||
temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
|
||||
|
||||
model = HybridDiffusionModel(
|
||||
cont_dim=len(cont_cols),
|
||||
@@ -166,6 +170,20 @@ def main():
|
||||
model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
|
||||
model.eval()
|
||||
|
||||
temporal_model = None
|
||||
if use_temporal_stage1:
|
||||
temporal_model = TemporalGRUGenerator(
|
||||
input_dim=len(cont_cols),
|
||||
hidden_dim=temporal_hidden_dim,
|
||||
num_layers=temporal_num_layers,
|
||||
dropout=temporal_dropout,
|
||||
).to(device)
|
||||
temporal_path = Path(args.model_path).with_name("temporal.pt")
|
||||
if not temporal_path.exists():
|
||||
raise SystemExit(f"missing temporal model file: {temporal_path}")
|
||||
temporal_model.load_state_dict(torch.load(temporal_path, map_location=device, weights_only=True))
|
||||
temporal_model.eval()
|
||||
|
||||
betas = cosine_beta_schedule(args.timesteps).to(device)
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
@@ -192,6 +210,10 @@ def main():
|
||||
cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
|
||||
cond = cond_id
|
||||
|
||||
trend = None
|
||||
if temporal_model is not None:
|
||||
trend = temporal_model.generate(args.batch_size, args.seq_len, device)
|
||||
|
||||
for t in reversed(range(args.timesteps)):
|
||||
t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
|
||||
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
|
||||
@@ -232,6 +254,8 @@ def main():
|
||||
)
|
||||
x_disc[:, :, i][mask] = sampled[mask]
|
||||
|
||||
if trend is not None:
|
||||
x_cont = x_cont + trend
|
||||
# move to CPU for export
|
||||
x_cont = x_cont.cpu()
|
||||
x_disc = x_disc.cpu()
|
||||
|
||||
@@ -84,6 +84,46 @@ class FeatureGraphMixer(nn.Module):
|
||||
return x + mixed
|
||||
|
||||
|
||||
class TemporalGRUGenerator(nn.Module):
|
||||
"""Stage-1 temporal generator (autoregressive GRU) for sequence backbone."""
|
||||
|
||||
def __init__(self, input_dim: int, hidden_dim: int = 256, num_layers: int = 1, dropout: float = 0.0):
|
||||
super().__init__()
|
||||
self.start_token = nn.Parameter(torch.zeros(input_dim))
|
||||
self.gru = nn.GRU(
|
||||
input_dim,
|
||||
hidden_dim,
|
||||
num_layers=num_layers,
|
||||
dropout=dropout if num_layers > 1 else 0.0,
|
||||
batch_first=True,
|
||||
)
|
||||
self.out = nn.Linear(hidden_dim, input_dim)
|
||||
|
||||
def forward_teacher(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Teacher-forced next-step prediction. Returns trend sequence and preds."""
|
||||
if x.size(1) < 2:
|
||||
raise ValueError("sequence length must be >= 2 for teacher forcing")
|
||||
inp = x[:, :-1, :]
|
||||
out, _ = self.gru(inp)
|
||||
pred_next = self.out(out)
|
||||
trend = torch.zeros_like(x)
|
||||
trend[:, 0, :] = x[:, 0, :]
|
||||
trend[:, 1:, :] = pred_next
|
||||
return trend, pred_next
|
||||
|
||||
def generate(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
|
||||
"""Autoregressively generate a backbone sequence."""
|
||||
h = None
|
||||
prev = self.start_token.unsqueeze(0).expand(batch_size, -1).to(device)
|
||||
outputs = []
|
||||
for _ in range(seq_len):
|
||||
out, h = self.gru(prev.unsqueeze(1), h)
|
||||
nxt = self.out(out.squeeze(1))
|
||||
outputs.append(nxt.unsqueeze(1))
|
||||
prev = nxt
|
||||
return torch.cat(outputs, dim=1)
|
||||
|
||||
|
||||
class HybridDiffusionModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -75,6 +75,7 @@ def main():
|
||||
run([sys.executable, str(base_dir / "evaluate_generated.py"), "--reference", str(ref)])
|
||||
else:
|
||||
run([sys.executable, str(base_dir / "evaluate_generated.py")])
|
||||
run([sys.executable, str(base_dir / "summary_metrics.py")])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -11,6 +11,7 @@ import torch.nn.functional as F
|
||||
|
||||
from data_utils import load_split
|
||||
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
|
||||
from hybrid_diffusion import TemporalGRUGenerator
|
||||
from platform_utils import resolve_device, safe_path, ensure_dir
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
@@ -59,6 +60,10 @@ def main():
|
||||
model_use_feature_graph = bool(cfg.get("model_use_feature_graph", False))
|
||||
feature_graph_scale = float(cfg.get("feature_graph_scale", 0.1))
|
||||
feature_graph_dropout = float(cfg.get("feature_graph_dropout", 0.0))
|
||||
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
|
||||
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
|
||||
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
|
||||
temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
|
||||
|
||||
split = load_split(str(SPLIT_PATH))
|
||||
time_col = split.get("time_column", "time")
|
||||
@@ -98,6 +103,20 @@ def main():
|
||||
model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
|
||||
model.eval()
|
||||
|
||||
temporal_model = None
|
||||
if use_temporal_stage1:
|
||||
temporal_model = TemporalGRUGenerator(
|
||||
input_dim=len(cont_cols),
|
||||
hidden_dim=temporal_hidden_dim,
|
||||
num_layers=temporal_num_layers,
|
||||
dropout=temporal_dropout,
|
||||
).to(DEVICE)
|
||||
temporal_path = BASE_DIR / "results" / "temporal.pt"
|
||||
if not temporal_path.exists():
|
||||
raise SystemExit(f"missing temporal model file: {temporal_path}")
|
||||
temporal_model.load_state_dict(torch.load(str(temporal_path), map_location=DEVICE, weights_only=True))
|
||||
temporal_model.eval()
|
||||
|
||||
betas = cosine_beta_schedule(timesteps).to(DEVICE)
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
@@ -116,6 +135,10 @@ def main():
|
||||
raise SystemExit("use_condition enabled but no files matched data_glob")
|
||||
cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long)
|
||||
|
||||
trend = None
|
||||
if temporal_model is not None:
|
||||
trend = temporal_model.generate(batch_size, seq_len, DEVICE)
|
||||
|
||||
for t in reversed(range(timesteps)):
|
||||
t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
|
||||
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
|
||||
@@ -155,6 +178,8 @@ def main():
|
||||
sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(BATCH_SIZE, SEQ_LEN)
|
||||
x_disc[:, :, i][mask] = sampled[mask]
|
||||
|
||||
if trend is not None:
|
||||
x_cont = x_cont + trend
|
||||
print("sampled_cont_shape", tuple(x_cont.shape))
|
||||
print("sampled_disc_shape", tuple(x_disc.shape))
|
||||
|
||||
|
||||
40
example/summary_metrics.py
Normal file
40
example/summary_metrics.py
Normal file
@@ -0,0 +1,40 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Print average metrics from eval.json for quick tracking."""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def mean(values):
|
||||
return sum(values) / len(values) if values else None
|
||||
|
||||
|
||||
def main():
|
||||
base_dir = Path(__file__).resolve().parent
|
||||
eval_path = base_dir / "results" / "eval.json"
|
||||
if not eval_path.exists():
|
||||
raise SystemExit(f"missing eval.json: {eval_path}")
|
||||
|
||||
obj = json.loads(eval_path.read_text(encoding="utf-8"))
|
||||
ks = list(obj.get("continuous_ks", {}).values())
|
||||
jsd = list(obj.get("discrete_jsd", {}).values())
|
||||
lag = list(obj.get("continuous_lag1_diff", {}).values())
|
||||
|
||||
avg_ks = mean(ks)
|
||||
avg_jsd = mean(jsd)
|
||||
avg_lag1 = mean(lag)
|
||||
|
||||
print("avg_ks", avg_ks)
|
||||
print("avg_jsd", avg_jsd)
|
||||
print("avg_lag1_diff", avg_lag1)
|
||||
|
||||
history_path = base_dir / "results" / "metrics_history.csv"
|
||||
if not history_path.exists():
|
||||
history_path.write_text("timestamp,avg_ks,avg_jsd,avg_lag1_diff\n", encoding="utf-8")
|
||||
with history_path.open("a", encoding="utf-8") as f:
|
||||
f.write(f"{datetime.utcnow().isoformat()},{avg_ks},{avg_jsd},{avg_lag1}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -14,6 +14,7 @@ import torch.nn.functional as F
|
||||
from data_utils import load_split, windowed_batches
|
||||
from hybrid_diffusion import (
|
||||
HybridDiffusionModel,
|
||||
TemporalGRUGenerator,
|
||||
cosine_beta_schedule,
|
||||
q_sample_continuous,
|
||||
q_sample_discrete,
|
||||
@@ -61,6 +62,11 @@ DEFAULTS = {
|
||||
"model_use_feature_graph": True,
|
||||
"feature_graph_scale": 0.1,
|
||||
"feature_graph_dropout": 0.0,
|
||||
"use_temporal_stage1": True,
|
||||
"temporal_hidden_dim": 256,
|
||||
"temporal_num_layers": 1,
|
||||
"temporal_dropout": 0.0,
|
||||
"temporal_loss_weight": 1.0,
|
||||
"disc_mask_scale": 0.9,
|
||||
"shuffle_buffer": 256,
|
||||
"cont_loss_weighting": "none", # none | inv_std
|
||||
@@ -204,7 +210,19 @@ def main():
|
||||
use_tanh_eps=bool(config.get("use_tanh_eps", False)),
|
||||
eps_scale=float(config.get("eps_scale", 1.0)),
|
||||
).to(device)
|
||||
temporal_model = None
|
||||
if bool(config.get("use_temporal_stage1", False)):
|
||||
temporal_model = TemporalGRUGenerator(
|
||||
input_dim=len(cont_cols),
|
||||
hidden_dim=int(config.get("temporal_hidden_dim", 256)),
|
||||
num_layers=int(config.get("temporal_num_layers", 1)),
|
||||
dropout=float(config.get("temporal_dropout", 0.0)),
|
||||
).to(device)
|
||||
opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
|
||||
if temporal_model is not None:
|
||||
opt_temporal = torch.optim.Adam(temporal_model.parameters(), lr=float(config["lr"]))
|
||||
else:
|
||||
opt_temporal = None
|
||||
ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None
|
||||
|
||||
betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
|
||||
@@ -250,10 +268,20 @@ def main():
|
||||
x_cont = x_cont.to(device)
|
||||
x_disc = x_disc.to(device)
|
||||
|
||||
temporal_loss = None
|
||||
x_cont_resid = x_cont
|
||||
trend = None
|
||||
if temporal_model is not None:
|
||||
trend, pred_next = temporal_model.forward_teacher(x_cont)
|
||||
temporal_loss = F.mse_loss(pred_next, x_cont[:, 1:, :])
|
||||
temporal_loss = temporal_loss * float(config.get("temporal_loss_weight", 1.0))
|
||||
trend = trend.detach()
|
||||
x_cont_resid = x_cont - trend
|
||||
|
||||
bsz = x_cont.size(0)
|
||||
t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
|
||||
|
||||
x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
|
||||
x_cont_t, noise = q_sample_continuous(x_cont_resid, t, alphas_cumprod)
|
||||
|
||||
mask_tokens = torch.tensor(vocab_sizes, device=device)
|
||||
x_disc_t, mask = q_sample_discrete(
|
||||
@@ -268,13 +296,13 @@ def main():
|
||||
|
||||
cont_target = str(config.get("cont_target", "eps"))
|
||||
if cont_target == "x0":
|
||||
x0_target = x_cont
|
||||
x0_target = x_cont_resid
|
||||
if float(config.get("cont_clamp_x0", 0.0)) > 0:
|
||||
x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"]))
|
||||
loss_base = (eps_pred - x0_target) ** 2
|
||||
elif cont_target == "v":
|
||||
a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
|
||||
v_target = torch.sqrt(a_bar_t) * noise - torch.sqrt(1.0 - a_bar_t) * x_cont
|
||||
v_target = torch.sqrt(a_bar_t) * noise - torch.sqrt(1.0 - a_bar_t) * x_cont_resid
|
||||
loss_base = (eps_pred - v_target) ** 2
|
||||
else:
|
||||
loss_base = (eps_pred - noise) ** 2
|
||||
@@ -311,7 +339,7 @@ def main():
|
||||
q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
|
||||
q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
|
||||
# Use normalized space for stable quantiles on x0.
|
||||
x_real = x_cont
|
||||
x_real = x_cont_resid
|
||||
a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
|
||||
if cont_target == "x0":
|
||||
x_gen = eps_pred
|
||||
@@ -336,11 +364,18 @@ def main():
|
||||
else:
|
||||
quantile_loss = torch.mean(torch.abs(q_diff))
|
||||
loss = loss + q_weight * quantile_loss
|
||||
|
||||
opt.zero_grad()
|
||||
loss.backward()
|
||||
if float(config.get("grad_clip", 0.0)) > 0:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
|
||||
opt.step()
|
||||
if opt_temporal is not None:
|
||||
opt_temporal.zero_grad()
|
||||
temporal_loss.backward()
|
||||
if float(config.get("grad_clip", 0.0)) > 0:
|
||||
torch.nn.utils.clip_grad_norm_(temporal_model.parameters(), float(config["grad_clip"]))
|
||||
opt_temporal.step()
|
||||
if ema is not None:
|
||||
ema.update(model)
|
||||
|
||||
@@ -375,11 +410,15 @@ def main():
|
||||
}
|
||||
if ema is not None:
|
||||
ckpt["ema"] = ema.state_dict()
|
||||
if temporal_model is not None:
|
||||
ckpt["temporal"] = temporal_model.state_dict()
|
||||
torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))
|
||||
|
||||
torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
|
||||
if ema is not None:
|
||||
torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))
|
||||
if temporal_model is not None:
|
||||
torch.save(temporal_model.state_dict(), os.path.join(out_dir, "temporal.pt"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user