2026-01-26 22:17:35 +08:00
parent 2e273fb8a2
commit e88b1cab91
9 changed files with 447 additions and 4 deletions

View File

@@ -72,3 +72,4 @@ python example/run_pipeline.py --device auto
- The script only samples the first 5000 rows to stay fast.
- `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
- `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU.
- Optional two-stage temporal model (`use_temporal_stage1`): a GRU trend backbone is trained first, and the diffusion model then learns the residuals around that trend (see the sketch below).
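A minimal sketch of the two-stage scheme, using the classes and helpers added in this commit; the tensor shapes and the 1000 timesteps are illustrative, and `reverse_diffusion` stands in for the sampling loop in `sample.py`:

import torch
import torch.nn.functional as F
from hybrid_diffusion import TemporalGRUGenerator, cosine_beta_schedule, q_sample_continuous

x_cont = torch.randn(8, 128, 4)  # (batch, seq_len, continuous features) -- illustrative shapes
gen = TemporalGRUGenerator(input_dim=4, hidden_dim=256)

# Stage 1: teacher-forced next-step prediction learns the trend backbone.
trend, pred_next = gen.forward_teacher(x_cont)
stage1_loss = F.mse_loss(pred_next, x_cont[:, 1:, :])

# Stage 2: diffusion is trained on the residual around the (frozen) trend.
residual = x_cont - trend.detach()  # train.py computes the trend under torch.no_grad()
betas = cosine_beta_schedule(1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
t = torch.randint(0, 1000, (8,))
x_t, noise = q_sample_continuous(residual, t, alphas_cumprod)

# Sampling: reverse diffusion produces a residual; the generated trend is added back.
trend_s = gen.generate(batch_size=8, seq_len=128, device=torch.device("cpu"))
# x_cont_sampled = reverse_diffusion(...) + trend_s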

View File

@@ -38,6 +38,12 @@
"cont_target": "x0",
"cont_clamp_x0": 5.0,
"shuffle_buffer": 256,
"use_temporal_stage1": true,
"temporal_hidden_dim": 256,
"temporal_num_layers": 1,
"temporal_dropout": 0.0,
"temporal_epochs": 2,
"temporal_lr": 0.001,
"sample_batch_size": 8,
"sample_seq_len": 128
}
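The new keys all configure stage 1: `use_temporal_stage1` toggles it, `temporal_hidden_dim` and `temporal_num_layers` size the GRU, `temporal_dropout` is applied only between stacked GRU layers (so it has no effect with a single layer), and `temporal_epochs` / `temporal_lr` drive the separate pre-training loop in train.py (`temporal_lr` falls back to `lr` when unset). To return to single-stage behaviour, only the flag needs to change; the remaining `temporal_*` keys are then ignored:

"use_temporal_stage1": false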

View File

@@ -13,7 +13,7 @@ import torch
import torch.nn.functional as F
from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
@@ -140,6 +140,10 @@ def main():
raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob)
cont_target = str(cfg.get("cont_target", "eps"))
cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
model = HybridDiffusionModel(
cont_dim=len(cont_cols),
@@ -163,6 +167,20 @@ def main():
model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
model.eval()
temporal_model = None
if use_temporal_stage1:
temporal_model = TemporalGRUGenerator(
input_dim=len(cont_cols),
hidden_dim=temporal_hidden_dim,
num_layers=temporal_num_layers,
dropout=temporal_dropout,
).to(device)
temporal_path = Path(args.model_path).with_name("temporal.pt")
if not temporal_path.exists():
raise SystemExit(f"missing temporal model file: {temporal_path}")
temporal_model.load_state_dict(torch.load(temporal_path, map_location=device, weights_only=True))
temporal_model.eval()
betas = cosine_beta_schedule(args.timesteps).to(device)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
@@ -189,6 +207,10 @@ def main():
cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
cond = cond_id
trend = None
if temporal_model is not None:
trend = temporal_model.generate(args.batch_size, args.seq_len, device)
for t in reversed(range(args.timesteps)):
t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
@@ -225,6 +247,8 @@ def main():
)
x_disc[:, :, i][mask] = sampled[mask]
if trend is not None:
x_cont = x_cont + trend
# move to CPU for export
x_cont = x_cont.cpu()
x_disc = x_disc.cpu()
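With stage 1 enabled, the reverse-diffusion loop above operates on the residual: the trend is generated once before the loop with `temporal_model.generate(...)` and is added back to `x_cont` only after the last denoising step, just before the tensors are moved to CPU for export. The sampler expects the stage-1 weights in a `temporal.pt` file next to `--model_path` and exits with an error if it is missing.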

View File

@@ -66,6 +66,46 @@ class SinusoidalTimeEmbedding(nn.Module):
return emb
class TemporalGRUGenerator(nn.Module):
"""Stage-1 temporal generator for sequence trend."""
def __init__(self, input_dim: int, hidden_dim: int = 256, num_layers: int = 1, dropout: float = 0.0):
super().__init__()
self.start_token = nn.Parameter(torch.zeros(input_dim))
self.gru = nn.GRU(
input_dim,
hidden_dim,
num_layers=num_layers,
dropout=dropout if num_layers > 1 else 0.0,
batch_first=True,
)
self.out = nn.Linear(hidden_dim, input_dim)
def forward_teacher(self, x: torch.Tensor) -> torch.Tensor:
"""Teacher-forced next-step prediction. Returns trend sequence and preds."""
if x.size(1) < 2:
raise ValueError("sequence length must be >= 2 for teacher forcing")
inp = x[:, :-1, :]
out, _ = self.gru(inp)
pred_next = self.out(out)
trend = torch.zeros_like(x)
trend[:, 0, :] = x[:, 0, :]
trend[:, 1:, :] = pred_next
return trend, pred_next
def generate(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
"""Autoregressively generate a backbone sequence."""
h = None
prev = self.start_token.unsqueeze(0).expand(batch_size, -1).to(device)
outputs = []
for _ in range(seq_len):
out, h = self.gru(prev.unsqueeze(1), h)
nxt = self.out(out.squeeze(1))
outputs.append(nxt.unsqueeze(1))
prev = nxt
return torch.cat(outputs, dim=1)
class HybridDiffusionModel(nn.Module):
def __init__(
self,
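A note on the new module, plus a minimal self-contained check (the feature dimension 4 is illustrative): `generate()` is fully deterministic. It starts every row from the same learned `start_token` and feeds its own predictions back, so for fixed weights the backbone is identical across calls and across the batch; sequence diversity comes from the diffusion residual added on top.

import torch
from hybrid_diffusion import TemporalGRUGenerator

gen = TemporalGRUGenerator(input_dim=4)  # hidden_dim=256, num_layers=1 by default
a = gen.generate(batch_size=2, seq_len=16, device=torch.device("cpu"))
b = gen.generate(batch_size=2, seq_len=16, device=torch.device("cpu"))
# same weights -> same backbone, and every batch row is identical
assert torch.allclose(a, b) and torch.allclose(a[0], a[1])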

View File

@@ -75,6 +75,8 @@ def main():
run([sys.executable, str(base_dir / "evaluate_generated.py"), "--reference", str(ref)])
else:
run([sys.executable, str(base_dir / "evaluate_generated.py")])
run([sys.executable, str(base_dir / "summary_metrics.py")])
run([sys.executable, str(base_dir / "summary_metrics.py")])
if __name__ == "__main__":

View File

@@ -10,7 +10,7 @@ import torch
import torch.nn.functional as F
from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir
BASE_DIR = Path(__file__).resolve().parent
@@ -47,6 +47,10 @@ def main():
cond_dim = int(cfg.get("cond_dim", 32))
use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
eps_scale = float(cfg.get("eps_scale", 1.0))
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
cont_target = str(cfg.get("cont_target", "eps"))
cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
model_time_dim = int(cfg.get("model_time_dim", 64))
@@ -92,6 +96,20 @@ def main():
model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
model.eval()
temporal_model = None
if use_temporal_stage1:
temporal_model = TemporalGRUGenerator(
input_dim=len(cont_cols),
hidden_dim=temporal_hidden_dim,
num_layers=temporal_num_layers,
dropout=temporal_dropout,
).to(DEVICE)
temporal_path = BASE_DIR / "results" / "temporal.pt"
if not temporal_path.exists():
raise SystemExit(f"missing temporal model file: {temporal_path}")
temporal_model.load_state_dict(torch.load(str(temporal_path), map_location=DEVICE, weights_only=True))
temporal_model.eval()
betas = cosine_beta_schedule(timesteps).to(DEVICE)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
@@ -110,6 +128,10 @@ def main():
raise SystemExit("use_condition enabled but no files matched data_glob")
cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long)
trend = None
if temporal_model is not None:
trend = temporal_model.generate(batch_size, seq_len, DEVICE)
for t in reversed(range(timesteps)):
t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
@@ -146,6 +168,8 @@ def main():
sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(BATCH_SIZE, SEQ_LEN)
x_disc[:, :, i][mask] = sampled[mask]
if trend is not None:
x_cont = x_cont + trend
print("sampled_cont_shape", tuple(x_cont.shape))
print("sampled_disc_shape", tuple(x_disc.shape))

View File

@@ -0,0 +1,63 @@
#!/usr/bin/env python3
"""Print average metrics from eval.json and compare with previous run."""
import json
from datetime import datetime
from pathlib import Path
def mean(values):
return sum(values) / len(values) if values else None
def parse_last_row(history_path: Path):
if not history_path.exists():
return None
rows = history_path.read_text(encoding="utf-8").strip().splitlines()
if len(rows) < 2:
return None
last = rows[-1].split(",")
if len(last) < 4:
return None
return {
"avg_ks": float(last[1]),
"avg_jsd": float(last[2]),
"avg_lag1_diff": float(last[3]),
}
def main():
base_dir = Path(__file__).resolve().parent
eval_path = base_dir / "results" / "eval.json"
if not eval_path.exists():
raise SystemExit(f"missing eval.json: {eval_path}")
obj = json.loads(eval_path.read_text(encoding="utf-8"))
ks = list(obj.get("continuous_ks", {}).values())
jsd = list(obj.get("discrete_jsd", {}).values())
lag = list(obj.get("continuous_lag1_diff", {}).values())
avg_ks = mean(ks)
avg_jsd = mean(jsd)
avg_lag1 = mean(lag)
history_path = base_dir / "results" / "metrics_history.csv"
prev = parse_last_row(history_path)
if not history_path.exists():
history_path.write_text("timestamp,avg_ks,avg_jsd,avg_lag1_diff\n", encoding="utf-8")
with history_path.open("a", encoding="utf-8") as f:
f.write(f"{datetime.utcnow().isoformat()},{avg_ks},{avg_jsd},{avg_lag1}\n")
print("avg_ks", avg_ks)
print("avg_jsd", avg_jsd)
print("avg_lag1_diff", avg_lag1)
if prev is not None:
print("delta_avg_ks", avg_ks - prev["avg_ks"])
print("delta_avg_jsd", avg_jsd - prev["avg_jsd"])
print("delta_avg_lag1_diff", avg_lag1 - prev["avg_lag1_diff"])
if __name__ == "__main__":
main()
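For reference, the input and output formats the script assumes (column names and numeric values below are placeholders): `eval.json` must contain per-column dictionaries under `continuous_ks`, `discrete_jsd`, and `continuous_lag1_diff`; the script averages each, appends one row per run to `results/metrics_history.csv`, and prints deltas against the previous row when one exists.

eval.json (placeholder columns and values):
{"continuous_ks": {"col_a": 0.12, "col_b": 0.08}, "discrete_jsd": {"col_c": 0.05}, "continuous_lag1_diff": {"col_a": 0.02, "col_b": 0.03}}

results/metrics_history.csv:
timestamp,avg_ks,avg_jsd,avg_lag1_diff
2026-01-26T14:20:11.123456,0.1,0.05,0.025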

View File

@@ -14,6 +14,7 @@ import torch.nn.functional as F
from data_utils import load_split, windowed_batches
from hybrid_diffusion import (
HybridDiffusionModel,
TemporalGRUGenerator,
cosine_beta_schedule,
q_sample_continuous,
q_sample_discrete,
@@ -64,6 +65,12 @@ DEFAULTS = {
"cont_loss_eps": 1e-6,
"cont_target": "eps", # eps | x0
"cont_clamp_x0": 0.0,
"use_temporal_stage1": True,
"temporal_hidden_dim": 256,
"temporal_num_layers": 1,
"temporal_dropout": 0.0,
"temporal_epochs": 2,
"temporal_lr": 1e-3,
}
@@ -194,6 +201,19 @@ def main():
eps_scale=float(config.get("eps_scale", 1.0)),
).to(device)
opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
temporal_model = None
opt_temporal = None
if bool(config.get("use_temporal_stage1", False)):
temporal_model = TemporalGRUGenerator(
input_dim=len(cont_cols),
hidden_dim=int(config.get("temporal_hidden_dim", 256)),
num_layers=int(config.get("temporal_num_layers", 1)),
dropout=float(config.get("temporal_dropout", 0.0)),
).to(device)
opt_temporal = torch.optim.Adam(
temporal_model.parameters(),
lr=float(config.get("temporal_lr", config["lr"])),
)
ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None
betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
@@ -208,6 +228,37 @@ def main():
with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f:
json.dump(config, f, indent=2)
if temporal_model is not None and opt_temporal is not None:
for epoch in range(int(config.get("temporal_epochs", 1))):
for step, batch in enumerate(
windowed_batches(
data_paths,
cont_cols,
disc_cols,
vocab,
mean,
std,
batch_size=int(config["batch_size"]),
seq_len=int(config["seq_len"]),
max_batches=int(config["max_batches"]),
return_file_id=False,
transforms=transforms,
shuffle_buffer=int(config.get("shuffle_buffer", 0)),
)
):
x_cont, _ = batch
x_cont = x_cont.to(device)
trend, pred_next = temporal_model.forward_teacher(x_cont)
temporal_loss = F.mse_loss(pred_next, x_cont[:, 1:, :])
opt_temporal.zero_grad()
temporal_loss.backward()
if float(config.get("grad_clip", 0.0)) > 0:
torch.nn.utils.clip_grad_norm_(temporal_model.parameters(), float(config["grad_clip"]))
opt_temporal.step()
if step % int(config["log_every"]) == 0:
print("temporal_epoch", epoch, "step", step, "loss", float(temporal_loss))
torch.save(temporal_model.state_dict(), os.path.join(out_dir, "temporal.pt"))
total_step = 0
for epoch in range(int(config["epochs"])):
for step, batch in enumerate(
@@ -235,10 +286,17 @@ def main():
x_cont = x_cont.to(device)
x_disc = x_disc.to(device)
trend = None
if temporal_model is not None:
temporal_model.eval()
with torch.no_grad():
trend, _ = temporal_model.forward_teacher(x_cont)
x_cont_resid = x_cont if trend is None else x_cont - trend
bsz = x_cont.size(0)
t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
x_cont_t, noise = q_sample_continuous(x_cont_resid, t, alphas_cumprod)
mask_tokens = torch.tensor(vocab_sizes, device=device)
x_disc_t, mask = q_sample_discrete(
@@ -253,7 +311,7 @@ def main():
cont_target = str(config.get("cont_target", "eps"))
if cont_target == "x0":
x0_target = x_cont
x0_target = x_cont_resid
if float(config.get("cont_clamp_x0", 0.0)) > 0:
x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"]))
loss_base = (eps_pred - x0_target) ** 2
@@ -308,11 +366,15 @@ def main():
}
if ema is not None:
ckpt["ema"] = ema.state_dict()
if temporal_model is not None:
ckpt["temporal"] = temporal_model.state_dict()
torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))
torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
if ema is not None:
torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))
if temporal_model is not None:
torch.save(temporal_model.state_dict(), os.path.join(out_dir, "temporal.pt"))
if __name__ == "__main__":
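Two practical consequences of the training changes above: the stage-1 weights are saved both into `model_ckpt.pt` under the `"temporal"` key and as a standalone `temporal.pt` in the output directory, which is the file the samplers load; and when `cont_target` is `"x0"`, the regression target becomes the residual `x_cont_resid`, so `cont_clamp_x0` now bounds residual magnitudes rather than absolute (normalised) values.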