12 Commits
ymz ... main

Author: MZ YANG

SHA1       Message               Date
dd4c1e171f update new structure  2026-01-26 19:00:16 +08:00
f8edee9510 update                2026-01-26 18:36:16 +08:00
cb610281ce update new structure  2026-01-26 18:27:41 +08:00
bc838d7cd7 update ks             2026-01-25 18:13:37 +08:00
b3c45010a4 update                2026-01-25 18:10:08 +08:00
2a1a9a05c6 update ks             2026-01-25 18:00:28 +08:00
cc10125fbf update ks             2026-01-25 17:55:28 +08:00
5ef1e465f9 update                2026-01-25 17:50:34 +08:00
5aba14d511 update ks             2026-01-25 17:39:28 +08:00
bdfc6e2aaa update                2026-01-25 17:35:48 +08:00
666bc3b8a9 update                2026-01-25 17:28:24 +08:00
a870945e33 update ks             2026-01-25 17:16:57 +08:00
14 changed files with 2353 additions and 1668 deletions

View File

@@ -72,3 +72,5 @@ python example/run_pipeline.py --device auto
 - The script only samples the first 5000 rows to stay fast.
 - `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
 - `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU.
+- Optional feature-graph mixer (`model_use_feature_graph`) adds a learnable relation prior across feature channels.
+- Optional two-stage temporal backbone (`use_temporal_stage1`) learns a GRU-based sequence trend; diffusion models the residual.

View File

@@ -32,11 +32,24 @@
"model_ff_mult": 2, "model_ff_mult": 2,
"model_pos_dim": 64, "model_pos_dim": 64,
"model_use_pos_embed": true, "model_use_pos_embed": true,
"model_use_feature_graph": true,
"feature_graph_scale": 0.1,
"feature_graph_dropout": 0.0,
"use_temporal_stage1": true,
"temporal_hidden_dim": 256,
"temporal_num_layers": 1,
"temporal_dropout": 0.0,
"temporal_loss_weight": 1.0,
"disc_mask_scale": 0.9, "disc_mask_scale": 0.9,
"cont_loss_weighting": "inv_std", "cont_loss_weighting": "inv_std",
"cont_loss_eps": 1e-6, "cont_loss_eps": 1e-6,
"cont_target": "x0", "cont_target": "v",
"cont_clamp_x0": 5.0, "cont_clamp_x0": 5.0,
"quantile_loss_weight": 0.1,
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
"quantile_loss_warmup_steps": 200,
"quantile_loss_clip": 6.0,
"quantile_loss_huber_delta": 1.0,
"shuffle_buffer": 256, "shuffle_buffer": 256,
"sample_batch_size": 8, "sample_batch_size": 8,
"sample_seq_len": 128 "sample_seq_len": 128

View File

@@ -13,7 +13,7 @@ import torch
 import torch.nn.functional as F
 from data_utils import load_split
-from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
+from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
 from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
@@ -140,6 +140,10 @@ def main():
         raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob)
     cont_target = str(cfg.get("cont_target", "eps"))
     cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
+    use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
+    temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
+    temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
+    temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
     model = HybridDiffusionModel(
         cont_dim=len(cont_cols),
@@ -151,6 +155,9 @@ def main():
         ff_mult=int(cfg.get("model_ff_mult", 2)),
         pos_dim=int(cfg.get("model_pos_dim", 64)),
         use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
+        use_feature_graph=bool(cfg.get("model_use_feature_graph", False)),
+        feature_graph_scale=float(cfg.get("feature_graph_scale", 0.1)),
+        feature_graph_dropout=float(cfg.get("feature_graph_dropout", 0.0)),
         cond_vocab_size=cond_vocab_size if use_condition else 0,
         cond_dim=int(cfg.get("cond_dim", 32)),
         use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
@@ -163,6 +170,20 @@ def main():
     model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
     model.eval()
+    temporal_model = None
+    if use_temporal_stage1:
+        temporal_model = TemporalGRUGenerator(
+            input_dim=len(cont_cols),
+            hidden_dim=temporal_hidden_dim,
+            num_layers=temporal_num_layers,
+            dropout=temporal_dropout,
+        ).to(device)
+        temporal_path = Path(args.model_path).with_name("temporal.pt")
+        if not temporal_path.exists():
+            raise SystemExit(f"missing temporal model file: {temporal_path}")
+        temporal_model.load_state_dict(torch.load(temporal_path, map_location=device, weights_only=True))
+        temporal_model.eval()
     betas = cosine_beta_schedule(args.timesteps).to(device)
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, dim=0)
@@ -189,6 +210,10 @@ def main():
         cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
         cond = cond_id
+    trend = None
+    if temporal_model is not None:
+        trend = temporal_model.generate(args.batch_size, args.seq_len, device)
     for t in reversed(range(args.timesteps)):
         t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
         eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
@@ -201,6 +226,10 @@ def main():
             if cont_clamp_x0 > 0:
                 x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
             eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
+        elif cont_target == "v":
+            v_pred = eps_pred
+            x0_pred = torch.sqrt(a_bar_t) * x_cont - torch.sqrt(1.0 - a_bar_t) * v_pred
+            eps_pred = torch.sqrt(1.0 - a_bar_t) * x_cont + torch.sqrt(a_bar_t) * v_pred
         coef1 = 1.0 / torch.sqrt(a_t)
         coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
         mean_x = coef1 * (x_cont - coef2 * eps_pred)
@@ -225,6 +254,8 @@ def main():
             )
             x_disc[:, :, i][mask] = sampled[mask]
+    if trend is not None:
+        x_cont = x_cont + trend
     # move to CPU for export
     x_cont = x_cont.cpu()
     x_disc = x_disc.cpu()

View File

@@ -66,6 +66,64 @@ class SinusoidalTimeEmbedding(nn.Module):
         return emb
+
+class FeatureGraphMixer(nn.Module):
+    """Learnable feature relation mixer (dataset-agnostic)."""
+
+    def __init__(self, dim: int, scale: float = 0.1, dropout: float = 0.0):
+        super().__init__()
+        self.scale = scale
+        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+        self.A = nn.Parameter(torch.zeros(dim, dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (B, T, D); symmetrize the relation matrix to stabilize training.
+        A = (self.A + self.A.t()) * 0.5
+        mixed = torch.matmul(x, A) * self.scale
+        mixed = self.dropout(mixed)
+        return x + mixed
+
+class TemporalGRUGenerator(nn.Module):
+    """Stage-1 temporal generator (autoregressive GRU) for the sequence backbone."""
+
+    def __init__(self, input_dim: int, hidden_dim: int = 256, num_layers: int = 1, dropout: float = 0.0):
+        super().__init__()
+        self.start_token = nn.Parameter(torch.zeros(input_dim))
+        self.gru = nn.GRU(
+            input_dim,
+            hidden_dim,
+            num_layers=num_layers,
+            dropout=dropout if num_layers > 1 else 0.0,
+            batch_first=True,
+        )
+        self.out = nn.Linear(hidden_dim, input_dim)
+
+    def forward_teacher(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Teacher-forced next-step prediction. Returns trend sequence and preds."""
+        if x.size(1) < 2:
+            raise ValueError("sequence length must be >= 2 for teacher forcing")
+        inp = x[:, :-1, :]
+        out, _ = self.gru(inp)
+        pred_next = self.out(out)
+        trend = torch.zeros_like(x)
+        trend[:, 0, :] = x[:, 0, :]
+        trend[:, 1:, :] = pred_next
+        return trend, pred_next
+
+    def generate(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
+        """Autoregressively generate a backbone sequence."""
+        h = None
+        prev = self.start_token.unsqueeze(0).expand(batch_size, -1).to(device)
+        outputs = []
+        for _ in range(seq_len):
+            out, h = self.gru(prev.unsqueeze(1), h)
+            nxt = self.out(out.squeeze(1))
+            outputs.append(nxt.unsqueeze(1))
+            prev = nxt
+        return torch.cat(outputs, dim=1)
+
 class HybridDiffusionModel(nn.Module):
     def __init__(
         self,
@@ -78,6 +136,9 @@ class HybridDiffusionModel(nn.Module):
         ff_mult: int = 2,
         pos_dim: int = 64,
         use_pos_embed: bool = True,
+        use_feature_graph: bool = False,
+        feature_graph_scale: float = 0.1,
+        feature_graph_dropout: float = 0.0,
         cond_vocab_size: int = 0,
         cond_dim: int = 32,
         use_tanh_eps: bool = False,
@@ -92,6 +153,7 @@ class HybridDiffusionModel(nn.Module):
         self.eps_scale = eps_scale
         self.pos_dim = pos_dim
         self.use_pos_embed = use_pos_embed
+        self.use_feature_graph = use_feature_graph
         self.cond_vocab_size = cond_vocab_size
         self.cond_dim = cond_dim
@@ -106,8 +168,17 @@ class HybridDiffusionModel(nn.Module):
         disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)
         self.cont_proj = nn.Linear(cont_dim, cont_dim)
+        self.feature_dim = cont_dim + disc_embed_dim
+        if use_feature_graph:
+            self.feature_graph = FeatureGraphMixer(
+                self.feature_dim,
+                scale=feature_graph_scale,
+                dropout=feature_graph_dropout,
+            )
+        else:
+            self.feature_graph = None
         pos_dim = pos_dim if use_pos_embed else 0
-        in_dim = cont_dim + disc_embed_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
+        in_dim = self.feature_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
         self.in_proj = nn.Linear(in_dim, hidden_dim)
         self.backbone = nn.GRU(
             hidden_dim,
@@ -149,7 +220,10 @@ class HybridDiffusionModel(nn.Module):
             cond_feat = self.cond_embed(cond).unsqueeze(1).expand(-1, x_cont.size(1), -1)
         cont_feat = self.cont_proj(x_cont)
-        parts = [cont_feat, disc_feat, time_emb]
+        feat = torch.cat([cont_feat, disc_feat], dim=-1)
+        if self.feature_graph is not None:
+            feat = self.feature_graph(feat)
+        parts = [feat, time_emb]
         if pos_emb is not None:
             parts.append(pos_emb.unsqueeze(0).expand(x_cont.size(0), -1, -1))
         if cond_feat is not None:

View File

@@ -34,6 +34,7 @@ def main():
     loss = []
     loss_cont = []
     loss_disc = []
+    loss_quant = []
     with log_path.open("r", encoding="utf-8", newline="") as f:
         reader = csv.DictReader(f)
@@ -42,6 +43,8 @@
             loss.append(float(row["loss"]))
             loss_cont.append(float(row["loss_cont"]))
             loss_disc.append(float(row["loss_disc"]))
+            if "loss_quantile" in row:
+                loss_quant.append(float(row["loss_quantile"]))
     if not steps:
         raise SystemExit("no rows in log file: %s" % log_path)
@@ -50,6 +53,8 @@
     plt.plot(steps, loss, label="total")
     plt.plot(steps, loss_cont, label="continuous")
     plt.plot(steps, loss_disc, label="discrete")
+    if loss_quant:
+        plt.plot(steps, loss_quant, label="quantile")
     plt.xlabel("step")
     plt.ylabel("loss")
     plt.title("Training Loss")

View File

@@ -32,12 +32,26 @@
   "model_ff_mult": 2,
   "model_pos_dim": 64,
   "model_use_pos_embed": true,
+  "model_use_feature_graph": true,
+  "feature_graph_scale": 0.1,
+  "feature_graph_dropout": 0.0,
   "disc_mask_scale": 0.9,
   "shuffle_buffer": 256,
   "cont_loss_weighting": "inv_std",
   "cont_loss_eps": 1e-06,
-  "cont_target": "x0",
+  "cont_target": "v",
   "cont_clamp_x0": 5.0,
+  "quantile_loss_weight": 0.1,
+  "quantile_points": [
+    0.05,
+    0.25,
+    0.5,
+    0.75,
+    0.95
+  ],
+  "quantile_loss_warmup_steps": 200,
+  "quantile_loss_clip": 6.0,
+  "quantile_loss_huber_delta": 1.0,
   "sample_batch_size": 8,
   "sample_seq_len": 128
 }

File diff suppressed because it is too large.

File diff suppressed because it is too large.

View File

@@ -1,61 +1,61 @@
-epoch,step,loss,loss_cont,loss_disc
-0,0,9801.450195,14001.333008,1.724242
-0,10,8648.388672,12354.390625,1.050415
-0,20,6285.666992,8979.191406,0.775979
-0,30,6296.939941,8995.372070,0.598658
-0,40,7126.128906,10180.011719,0.402291
-0,50,5381.071289,7687.099121,0.340463
-1,0,17876.904297,25538.212891,0.522857
-1,10,9174.909180,13106.868164,0.338315
-1,20,7635.462891,10907.713867,0.211883
-1,30,5425.212402,7750.165039,0.323606
-1,40,4372.716309,6246.610840,0.296102
-1,50,3846.437988,5494.793945,0.273941
-2,0,17057.958984,24368.269531,0.564671
-2,10,7089.009766,10127.021484,0.315390
-2,20,4230.856445,6043.994141,0.201429
-2,30,3744.107910,5348.593262,0.309118
-2,40,3531.041992,5044.219238,0.295744
-2,50,3570.459229,5100.528320,0.297530
-3,0,14367.601562,20524.917969,0.529623
-3,10,6734.334473,9620.347656,0.304721
-3,20,6140.179688,8771.602539,0.193836
-3,30,4089.454102,5841.940918,0.317864
-3,40,3553.830811,5076.785645,0.269531
-3,50,3590.448242,5129.088867,0.287063
-4,0,14410.816406,20586.648438,0.543822
-4,10,6411.443359,9159.058594,0.341742
-4,20,3816.795166,5452.479492,0.198213
-4,30,4069.170898,5812.959961,0.329676
-4,40,3484.921631,4978.332520,0.296284
-4,50,2802.801514,4003.873779,0.299286
-5,0,13335.293945,19050.201172,0.509557
-5,10,5531.156738,7901.527344,0.293409
-5,20,3844.260010,5491.696777,0.241263
-5,30,3619.303223,5170.297363,0.317237
-5,40,3492.172852,4988.697754,0.281641
-5,50,3069.457275,4384.815918,0.287269
-6,0,8740.982422,12486.912109,0.483061
-6,10,6110.571777,8729.239258,0.347929
-6,20,3350.194092,4785.889160,0.239650
-6,30,3008.237549,4297.353516,0.300327
-6,40,2944.483887,4206.288574,0.273457
-6,50,3018.033447,4311.365234,0.259749
-7,0,6784.341309,9691.704102,0.495378
-7,10,4946.872559,7066.822754,0.321541
-7,20,2816.704102,4023.745361,0.274515
-7,30,2991.350830,4273.229980,0.299238
-7,40,3023.450684,4319.083984,0.307114
-7,50,2944.912598,4206.909668,0.253002
-8,0,7364.851562,10521.002930,0.498283
-8,10,3897.301025,5567.439453,0.311517
-8,20,3313.474854,4733.448242,0.203364
-8,30,2697.139648,3852.940430,0.271355
-8,40,2955.225342,4221.633301,0.273175
-8,50,2932.081787,4188.575684,0.262651
-9,0,4065.651611,5807.854004,0.513017
-9,10,4358.108398,6225.728516,0.329111
-9,20,3417.019043,4881.362305,0.218488
-9,30,2917.136719,4167.226562,0.260805
-9,40,2786.277832,3980.287109,0.256345
-9,50,2602.660645,3717.971680,0.268723
+epoch,step,loss,loss_cont,loss_disc,loss_quantile
+0,0,11527.750000,16467.474609,1.724242,0.113109
+0,10,8955.048828,12792.469727,1.065955,0.142707
+0,20,123049.507812,175784.562500,1.048191,0.181546
+0,30,30947.671875,44210.511719,1.034107,0.224363
+0,40,11166.339844,15951.691406,0.512211,0.163978
+0,50,14919.276367,21313.089844,0.369016,0.101041
+1,0,1113425.000000,1590606.875000,0.667305,0.236102
+1,10,38804.527344,55433.980469,2.453400,0.238323
+1,20,138075.984375,197251.281250,0.288013,0.157941
+1,30,56904.078125,81291.343750,0.429261,0.174654
+1,40,10662.019531,15231.303711,0.343701,0.079876
+1,50,9890.013672,14128.458008,0.292848,0.088391
+2,0,912398.500000,1303426.125000,0.795286,0.256503
+2,10,12546.185547,17922.542969,1.328752,0.110444
+2,20,109676.710938,156680.906250,0.245523,0.132549
+2,30,49427.507812,70610.578125,0.327188,0.088591
+2,40,27778.673828,39683.660156,0.345683,0.102361
+2,50,10311.509766,14730.566406,0.350199,0.091674
+3,0,1040308.062500,1486154.000000,0.807157,0.279805
+3,10,64799.246094,92570.117188,0.485949,0.198213
+3,20,336018.000000,480025.531250,0.410466,0.118048
+3,30,94216.312500,134594.562500,0.355044,0.114209
+3,40,19988.919922,28555.457031,0.298066,0.094291
+3,50,9181.969727,13116.940430,0.326489,0.137392
+4,0,741176.187500,1058822.750000,0.850633,0.233118
+4,10,39252.617188,56074.410156,1.690607,0.227370
+4,20,108992.304688,155703.140625,0.258920,0.285944
+4,30,37115.253906,53021.632812,0.337674,0.131537
+4,40,19358.708984,27655.123047,0.349937,0.169446
+4,50,11434.291992,16334.540039,0.351904,0.089116
+5,0,845658.312500,1208083.000000,0.922483,0.285983
+5,10,130569.406250,186527.515625,0.405198,0.227639
+5,20,245780.390625,351114.687500,0.301236,0.147030
+5,30,42017.671875,60025.066406,0.372036,0.117895
+5,40,11496.740234,16423.779297,0.286911,0.085701
+5,50,8891.728516,12702.317383,0.322913,0.099181
+6,0,617909.687500,882727.812500,0.834663,0.205251
+6,10,18171.734375,25959.302734,0.677465,0.190570
+6,20,423716.187500,605308.687500,0.464189,0.156125
+6,30,48133.507812,68761.914062,0.478786,0.221465
+6,40,20350.281250,29071.666016,0.343182,0.118739
+6,50,11372.219727,16245.889648,0.289177,0.101218
+7,0,534133.500000,763047.500000,0.732465,0.239962
+7,10,19025.574219,27178.703125,1.551094,0.183023
+7,20,58042.757812,82918.078125,0.298854,0.146532
+7,30,11810.656250,16872.201172,0.353035,0.092791
+7,40,9733.656250,13905.077148,0.312043,0.089972
+7,50,9069.159180,12955.803711,0.280795,0.129749
+8,0,650366.937500,929095.250000,0.787194,0.263791
+8,10,13121.248047,18744.031250,1.376233,0.141053
+8,20,19326.507812,27609.167969,0.243424,0.183426
+8,30,12904.376953,18434.667969,0.334577,0.092547
+8,40,9682.833984,13832.478516,0.296229,0.109518
+8,50,8866.144531,12665.773438,0.312481,0.099564
+9,0,265689.562500,379556.156250,0.738840,0.186122
+9,10,9944.607422,14206.222656,0.806142,0.096763
+9,20,22091.175781,31558.679688,0.285403,0.131950
+9,30,19346.830078,27638.111328,0.460825,0.131496
+9,40,10279.702148,14685.139648,0.314872,0.095758
+9,50,10420.340820,14886.050781,0.314008,0.112904

View File

@@ -75,6 +75,7 @@ def main():
         run([sys.executable, str(base_dir / "evaluate_generated.py"), "--reference", str(ref)])
     else:
         run([sys.executable, str(base_dir / "evaluate_generated.py")])
+    run([sys.executable, str(base_dir / "summary_metrics.py")])
 if __name__ == "__main__":

View File

@@ -11,6 +11,7 @@ import torch.nn.functional as F
 from data_utils import load_split
 from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
+from hybrid_diffusion import TemporalGRUGenerator
 from platform_utils import resolve_device, safe_path, ensure_dir
 BASE_DIR = Path(__file__).resolve().parent
@@ -56,6 +57,13 @@ def main():
     model_ff_mult = int(cfg.get("model_ff_mult", 2))
     model_pos_dim = int(cfg.get("model_pos_dim", 64))
     model_use_pos = bool(cfg.get("model_use_pos_embed", True))
+    model_use_feature_graph = bool(cfg.get("model_use_feature_graph", False))
+    feature_graph_scale = float(cfg.get("feature_graph_scale", 0.1))
+    feature_graph_dropout = float(cfg.get("feature_graph_dropout", 0.0))
+    use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
+    temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
+    temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
+    temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
     split = load_split(str(SPLIT_PATH))
     time_col = split.get("time_column", "time")
@@ -83,6 +91,9 @@ def main():
         ff_mult=model_ff_mult,
         pos_dim=model_pos_dim,
         use_pos_embed=model_use_pos,
+        use_feature_graph=model_use_feature_graph,
+        feature_graph_scale=feature_graph_scale,
+        feature_graph_dropout=feature_graph_dropout,
         cond_vocab_size=cond_vocab_size,
         cond_dim=cond_dim,
         use_tanh_eps=use_tanh_eps,
@@ -92,6 +103,20 @@ def main():
     model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
     model.eval()
+    temporal_model = None
+    if use_temporal_stage1:
+        temporal_model = TemporalGRUGenerator(
+            input_dim=len(cont_cols),
+            hidden_dim=temporal_hidden_dim,
+            num_layers=temporal_num_layers,
+            dropout=temporal_dropout,
+        ).to(DEVICE)
+        temporal_path = BASE_DIR / "results" / "temporal.pt"
+        if not temporal_path.exists():
+            raise SystemExit(f"missing temporal model file: {temporal_path}")
+        temporal_model.load_state_dict(torch.load(str(temporal_path), map_location=DEVICE, weights_only=True))
+        temporal_model.eval()
     betas = cosine_beta_schedule(timesteps).to(DEVICE)
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, dim=0)
@@ -110,19 +135,26 @@ def main():
             raise SystemExit("use_condition enabled but no files matched data_glob")
         cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long)
+    trend = None
+    if temporal_model is not None:
+        trend = temporal_model.generate(batch_size, seq_len, DEVICE)
     for t in reversed(range(timesteps)):
         t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
         eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
+        # Continuous reverse step (DDPM): x_{t-1} mean
+        a_t = alphas[t]
+        a_bar_t = alphas_cumprod[t]
         if cont_target == "x0":
             x0_pred = eps_pred
             if cont_clamp_x0 > 0:
                 x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
             eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
+        elif cont_target == "v":
+            v_pred = eps_pred
+            x0_pred = torch.sqrt(a_bar_t) * x_cont - torch.sqrt(1.0 - a_bar_t) * v_pred
+            eps_pred = torch.sqrt(1.0 - a_bar_t) * x_cont + torch.sqrt(a_bar_t) * v_pred
-        # Continuous reverse step (DDPM): x_{t-1} mean
-        a_t = alphas[t]
-        a_bar_t = alphas_cumprod[t]
         coef1 = 1.0 / torch.sqrt(a_t)
         coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
         mean = coef1 * (x_cont - coef2 * eps_pred)
@@ -146,6 +178,8 @@ def main():
             sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(BATCH_SIZE, SEQ_LEN)
             x_disc[:, :, i][mask] = sampled[mask]
+    if trend is not None:
+        x_cont = x_cont + trend
     print("sampled_cont_shape", tuple(x_cont.shape))
     print("sampled_disc_shape", tuple(x_disc.shape))

View File

@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+"""Print average metrics from eval.json for quick tracking."""
+import json
+from datetime import datetime
+from pathlib import Path
+
+def mean(values):
+    return sum(values) / len(values) if values else None
+
+def main():
+    base_dir = Path(__file__).resolve().parent
+    eval_path = base_dir / "results" / "eval.json"
+    if not eval_path.exists():
+        raise SystemExit(f"missing eval.json: {eval_path}")
+    obj = json.loads(eval_path.read_text(encoding="utf-8"))
+    ks = list(obj.get("continuous_ks", {}).values())
+    jsd = list(obj.get("discrete_jsd", {}).values())
+    lag = list(obj.get("continuous_lag1_diff", {}).values())
+    avg_ks = mean(ks)
+    avg_jsd = mean(jsd)
+    avg_lag1 = mean(lag)
+    print("avg_ks", avg_ks)
+    print("avg_jsd", avg_jsd)
+    print("avg_lag1_diff", avg_lag1)
+    history_path = base_dir / "results" / "metrics_history.csv"
+    if not history_path.exists():
+        history_path.write_text("timestamp,avg_ks,avg_jsd,avg_lag1_diff\n", encoding="utf-8")
+    with history_path.open("a", encoding="utf-8") as f:
+        f.write(f"{datetime.utcnow().isoformat()},{avg_ks},{avg_jsd},{avg_lag1}\n")
+
+if __name__ == "__main__":
+    main()

View File

@@ -14,6 +14,7 @@ import torch.nn.functional as F
 from data_utils import load_split, windowed_batches
 from hybrid_diffusion import (
     HybridDiffusionModel,
+    TemporalGRUGenerator,
     cosine_beta_schedule,
     q_sample_continuous,
     q_sample_discrete,
@@ -58,12 +59,25 @@ DEFAULTS = {
     "model_ff_mult": 2,
     "model_pos_dim": 64,
     "model_use_pos_embed": True,
+    "model_use_feature_graph": True,
+    "feature_graph_scale": 0.1,
+    "feature_graph_dropout": 0.0,
+    "use_temporal_stage1": True,
+    "temporal_hidden_dim": 256,
+    "temporal_num_layers": 1,
+    "temporal_dropout": 0.0,
+    "temporal_loss_weight": 1.0,
     "disc_mask_scale": 0.9,
     "shuffle_buffer": 256,
     "cont_loss_weighting": "none",  # none | inv_std
     "cont_loss_eps": 1e-6,
-    "cont_target": "eps",  # eps | x0
+    "cont_target": "eps",  # eps | x0 | v
     "cont_clamp_x0": 0.0,
+    "quantile_loss_weight": 0.0,
+    "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
+    "quantile_loss_warmup_steps": 200,
+    "quantile_loss_clip": 6.0,
+    "quantile_loss_huber_delta": 1.0,
 }
@@ -188,12 +202,27 @@ def main():
         ff_mult=int(config.get("model_ff_mult", 2)),
         pos_dim=int(config.get("model_pos_dim", 64)),
         use_pos_embed=bool(config.get("model_use_pos_embed", True)),
+        use_feature_graph=bool(config.get("model_use_feature_graph", False)),
+        feature_graph_scale=float(config.get("feature_graph_scale", 0.1)),
+        feature_graph_dropout=float(config.get("feature_graph_dropout", 0.0)),
         cond_vocab_size=cond_vocab_size,
         cond_dim=int(config.get("cond_dim", 32)),
         use_tanh_eps=bool(config.get("use_tanh_eps", False)),
         eps_scale=float(config.get("eps_scale", 1.0)),
     ).to(device)
+    temporal_model = None
+    if bool(config.get("use_temporal_stage1", False)):
+        temporal_model = TemporalGRUGenerator(
+            input_dim=len(cont_cols),
+            hidden_dim=int(config.get("temporal_hidden_dim", 256)),
+            num_layers=int(config.get("temporal_num_layers", 1)),
+            dropout=float(config.get("temporal_dropout", 0.0)),
+        ).to(device)
     opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
+    if temporal_model is not None:
+        opt_temporal = torch.optim.Adam(temporal_model.parameters(), lr=float(config["lr"]))
+    else:
+        opt_temporal = None
     ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None
     betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
@@ -203,7 +232,11 @@ def main():
     os.makedirs(config["out_dir"], exist_ok=True)
     out_dir = safe_path(config["out_dir"])
     log_path = os.path.join(out_dir, "train_log.csv")
+    use_quantile = float(config.get("quantile_loss_weight", 0.0)) > 0
     with open(log_path, "w", encoding="utf-8") as f:
-        f.write("epoch,step,loss,loss_cont,loss_disc\n")
+        if use_quantile:
+            f.write("epoch,step,loss,loss_cont,loss_disc,loss_quantile\n")
+        else:
+            f.write("epoch,step,loss,loss_cont,loss_disc\n")
     with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f:
         json.dump(config, f, indent=2)
@@ -235,10 +268,20 @@ def main():
             x_cont = x_cont.to(device)
             x_disc = x_disc.to(device)
+            temporal_loss = None
+            x_cont_resid = x_cont
+            trend = None
+            if temporal_model is not None:
+                trend, pred_next = temporal_model.forward_teacher(x_cont)
+                temporal_loss = F.mse_loss(pred_next, x_cont[:, 1:, :])
+                temporal_loss = temporal_loss * float(config.get("temporal_loss_weight", 1.0))
+                trend = trend.detach()
+                x_cont_resid = x_cont - trend
             bsz = x_cont.size(0)
             t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
-            x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
+            x_cont_t, noise = q_sample_continuous(x_cont_resid, t, alphas_cumprod)
             mask_tokens = torch.tensor(vocab_sizes, device=device)
             x_disc_t, mask = q_sample_discrete(
@@ -253,10 +296,14 @@ def main():
             cont_target = str(config.get("cont_target", "eps"))
             if cont_target == "x0":
-                x0_target = x_cont
+                x0_target = x_cont_resid
                 if float(config.get("cont_clamp_x0", 0.0)) > 0:
                     x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"]))
                 loss_base = (eps_pred - x0_target) ** 2
+            elif cont_target == "v":
+                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
+                v_target = torch.sqrt(a_bar_t) * noise - torch.sqrt(1.0 - a_bar_t) * x_cont_resid
+                loss_base = (eps_pred - v_target) ** 2
             else:
                 loss_base = (eps_pred - noise) ** 2
@@ -282,17 +329,72 @@ def main():
             lam = float(config["lambda"])
             loss = lam * loss_cont + (1 - lam) * loss_disc
+            q_weight = float(config.get("quantile_loss_weight", 0.0))
+            quantile_loss = 0.0
+            if q_weight > 0:
+                warmup = int(config.get("quantile_loss_warmup_steps", 0))
+                if warmup > 0:
+                    q_weight = q_weight * min(1.0, (total_step + 1) / float(warmup))
+                q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
+                q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
+                # Use normalized space for stable quantiles on x0.
+                x_real = x_cont_resid
+                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
+                if cont_target == "x0":
+                    x_gen = eps_pred
+                elif cont_target == "v":
+                    v_pred = eps_pred
+                    x_gen = torch.sqrt(a_bar_t) * x_cont_t - torch.sqrt(1.0 - a_bar_t) * v_pred
+                else:
+                    # eps prediction
+                    x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
+                q_clip = float(config.get("quantile_loss_clip", 0.0))
+                if q_clip > 0:
+                    x_real = torch.clamp(x_real, -q_clip, q_clip)
+                    x_gen = torch.clamp(x_gen, -q_clip, q_clip)
+                x_real = x_real.view(-1, x_real.size(-1))
+                x_gen = x_gen.view(-1, x_gen.size(-1))
+                q_real = torch.quantile(x_real, q_tensor, dim=0)
+                q_gen = torch.quantile(x_gen, q_tensor, dim=0)
+                q_delta = float(config.get("quantile_loss_huber_delta", 0.0))
+                q_diff = q_gen - q_real
+                if q_delta > 0:
+                    quantile_loss = torch.nn.functional.smooth_l1_loss(q_gen, q_real, beta=q_delta)
+                else:
+                    quantile_loss = torch.mean(torch.abs(q_diff))
+                loss = loss + q_weight * quantile_loss
             opt.zero_grad()
             loss.backward()
             if float(config.get("grad_clip", 0.0)) > 0:
                 torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
             opt.step()
+            if opt_temporal is not None:
+                opt_temporal.zero_grad()
+                temporal_loss.backward()
+                if float(config.get("grad_clip", 0.0)) > 0:
+                    torch.nn.utils.clip_grad_norm_(temporal_model.parameters(), float(config["grad_clip"]))
+                opt_temporal.step()
             if ema is not None:
                 ema.update(model)
             if step % int(config["log_every"]) == 0:
                 print("epoch", epoch, "step", step, "loss", float(loss))
                 with open(log_path, "a", encoding="utf-8") as f:
+                    if use_quantile:
+                        f.write(
+                            "%d,%d,%.6f,%.6f,%.6f,%.6f\n"
+                            % (
+                                epoch,
+                                step,
+                                float(loss),
+                                float(loss_cont),
+                                float(loss_disc),
+                                float(quantile_loss),
+                            )
+                        )
+                    else:
                         f.write(
                             "%d,%d,%.6f,%.6f,%.6f\n"
                             % (epoch, step, float(loss), float(loss_cont), float(loss_disc))
                         )
@@ -308,11 +410,15 @@ def main():
     }
     if ema is not None:
         ckpt["ema"] = ema.state_dict()
+    if temporal_model is not None:
+        ckpt["temporal"] = temporal_model.state_dict()
     torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))
     torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
    if ema is not None:
        torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))
+    if temporal_model is not None:
+        torch.save(temporal_model.state_dict(), os.path.join(out_dir, "temporal.pt"))
 if __name__ == "__main__":

report.md (new file, 365 lines)
View File

@@ -0,0 +1,365 @@
# Hybrid Diffusion for ICS Traffic (HAI 21.03) — Project Report
## 1. Project Goal

Build a **hybrid diffusion-based generator** for industrial control system (ICS) traffic features, targeting **mixed continuous + discrete** feature sequences. The output is **feature-level sequences**, not raw packets. The generator should preserve:

- **Distributional fidelity** (continuous value ranges and discrete frequencies)
- **Temporal consistency** (time correlation and sequence structure)
- **Protocol/field consistency** (for discrete fields)

This project is aligned with the STOUTER idea of **structure-aware diffusion** for spatiotemporal data, but applied to **ICS feature sequences** rather than cellular traffic.
---
## 2. Data and Scope

**Dataset used in the current implementation:** HAI 21.03 (CSV feature traces)

**Data location (default in config):**
- `dataset/hai/hai-21.03/train*.csv.gz`

**Feature split (fixed schema):**
- Defined in `example/feature_split.json`
- **Continuous features:** sensor/process values
- **Discrete features:** binary/low-cardinality status/flag fields
- The `time` column is excluded from modeling
---
## 3. End-to-End Pipeline

**One-command pipeline:**

```
python example/run_all.py --device cuda
```

### Pipeline stages

1) **Prepare data** (`example/prepare_data.py`): computes normalization statistics and vocabularies
2) **Train model** (`example/train.py`)
3) **Generate samples** (`example/export_samples.py`)
4) **Evaluate** (`example/evaluate_generated.py`)
5) **Summarize metrics** (`example/summary_metrics.py`)
---
## 4. Technical Architecture

### 4.1 Hybrid Diffusion Model (Core)

Defined in `example/hybrid_diffusion.py`.

**Key components:**
- **Continuous branch**: Gaussian diffusion (DDPM style)
- **Discrete branch**: mask diffusion for categorical tokens
- **Shared backbone**: GRU + residual MLP + LayerNorm
- **Embedding inputs**:
  - continuous projection
  - discrete embeddings per column
  - time embedding (sinusoidal)
  - positional embedding (sequence index)
  - optional condition embedding (`file_id`)

**Outputs:**
- Continuous head: predicts the configured target (`eps`, `x0`, or `v`)
- Discrete heads: predict logits for each discrete column
---
### 4.2 Feature Graph Mixer (Structure Prior)

Implemented in `example/hybrid_diffusion.py` as `FeatureGraphMixer`.

Purpose: inject a **learnable feature-dependency prior** without dataset-specific hardcoding.

**Mechanism:**
- Learns a dense feature relation matrix `A`
- Applies: `x + x @ A`
- Symmetric stabilizing constraint: `(A + A^T)/2`
- Strength controlled by scale and dropout (see the sketch below)

**Config:**

```
"model_use_feature_graph": true,
"feature_graph_scale": 0.1,
"feature_graph_dropout": 0.0
```
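A minimal sketch of the mixing step, assuming an input batch of shape `(batch, time, feature_dim)` and the config defaults above:

```
import torch

B, T, D = 8, 128, 32
x = torch.randn(B, T, D)
A = torch.nn.Parameter(torch.zeros(D, D))  # learnable relation matrix, zero-initialized
A_sym = (A + A.t()) * 0.5                  # symmetrize to stabilize training
out = x + torch.matmul(x, A_sym) * 0.1     # residual mix with feature_graph_scale = 0.1
```

Zero initialization makes the mixer start as an identity map, so the relation prior is learned rather than imposed.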
---
### 4.3 Two-Stage Temporal Backbone

Stage 1 uses a **GRU temporal generator** to model the sequence trend in normalized space; stage-2 diffusion then models the **residual** (x - trend). This decouples temporal consistency from distribution alignment.
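A minimal sketch of the split as it appears in `train.py`:

```
import torch
import torch.nn.functional as F
from hybrid_diffusion import TemporalGRUGenerator

x_cont = torch.randn(8, 128, 4)            # normalized (B, T, D) batch
temporal_model = TemporalGRUGenerator(input_dim=4)

# Stage 1: teacher-forced next-step prediction (trained with its own optimizer).
trend, pred_next = temporal_model.forward_teacher(x_cont)
temporal_loss = F.mse_loss(pred_next, x_cont[:, 1:, :])

# Stage 2: diffusion is trained on the residual only; the trend is detached.
residual = x_cont - trend.detach()
# At sampling time the generated trend is added back:
# x = generated_residual + temporal_model.generate(batch_size, seq_len, device)
```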
---
## 5. Diffusion Formulations

### 5.1 Continuous Diffusion

Forward process:

```
x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
```

Targets supported:
- **eps prediction** (standard DDPM)
- **x0 prediction** (direct reconstruction)
- **v prediction** (v = sqrt(a_bar) * eps - sqrt(1 - a_bar) * x0)

Current config default:

```
"cont_target": "v"
```

Sampling uses the configured target to reconstruct `eps` and then applies the standard DDPM reverse update, as sketched below.
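A minimal sketch of the target conversions used at sampling time (mirroring `sample.py`), assuming `a_bar` is `alphas_cumprod[t]` broadcastable to `x_t`:

```
import torch

def eps_from_x0(x_t, x0_pred, a_bar):
    # Invert the forward process to recover the implied noise.
    return (x_t - torch.sqrt(a_bar) * x0_pred) / torch.sqrt(1.0 - a_bar)

def x0_from_v(x_t, v_pred, a_bar):
    return torch.sqrt(a_bar) * x_t - torch.sqrt(1.0 - a_bar) * v_pred

def eps_from_v(x_t, v_pred, a_bar):
    return torch.sqrt(1.0 - a_bar) * x_t + torch.sqrt(a_bar) * v_pred
```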
---
### 5.2 Discrete Diffusion (Mask)

Forward process: replace tokens with `[MASK]` according to a cosine schedule:

```
p(t) = 0.5 * (1 - cos(pi * t / T))
```

Optional scale: `disc_mask_scale`

Reverse process: cross-entropy on masked positions only.
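A minimal sketch of the per-step mask probability, assuming `disc_mask_scale` simply scales the cosine schedule (0.9 in the current config):

```
import math

def mask_prob(t: int, T: int, scale: float = 0.9) -> float:
    # Fraction of discrete tokens replaced by [MASK] at diffusion step t.
    return scale * 0.5 * (1.0 - math.cos(math.pi * t / T))
```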
---
## 6. Loss Design (Current)

Total loss:

```
L = λ * L_cont + (1 - λ) * L_disc + w_q * L_quantile
```

### 6.1 Continuous Loss

Depending on `cont_target`:
- eps target: MSE(eps_pred, eps)
- x0 target: MSE(x0_pred, x0)
- v target: MSE(v_pred, v_target)

Optional inverse-variance weighting:

```
cont_loss_weighting = "inv_std"
```

### 6.2 Discrete Loss

Cross-entropy on masked positions only.

### 6.3 Quantile Loss (Distribution Alignment)

Added to improve KS (distribution shape alignment):
- Compute quantiles of generated vs. real x0
- Take a Huber or L1 difference on the quantiles

Stabilization knobs:

```
quantile_loss_warmup_steps
quantile_loss_clip
quantile_loss_huber_delta
```
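A minimal sketch of the quantile term, following the clip/quantile/Huber logic in `train.py`:

```
import torch
import torch.nn.functional as F

def quantile_loss(x_real, x_gen, q_points=(0.05, 0.25, 0.5, 0.75, 0.95),
                  clip=6.0, huber_delta=1.0):
    # Clamp outliers, then flatten batch/time so quantiles are per feature.
    q = torch.tensor(q_points, dtype=x_real.dtype, device=x_real.device)
    x_real = x_real.clamp(-clip, clip).reshape(-1, x_real.size(-1))
    x_gen = x_gen.clamp(-clip, clip).reshape(-1, x_gen.size(-1))
    q_real = torch.quantile(x_real, q, dim=0)
    q_gen = torch.quantile(x_gen, q, dim=0)
    return F.smooth_l1_loss(q_gen, q_real, beta=huber_delta)
```

The warmup simply ramps `quantile_loss_weight` linearly over the first `quantile_loss_warmup_steps` steps before this term is applied at full strength.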
---
## 7. Training Strategy

Defined in `example/train.py`.

**Key techniques:**
- EMA of model weights
- Gradient clipping
- Shuffle buffer to reduce batch bias
- Optional feature-graph prior
- Quantile-loss warmup for stability
- Optional stage-1 temporal GRU (trend) + residual diffusion

**Config highlights (`example/config.json`):**

```
timesteps: 600
batch_size: 128
seq_len: 128
epochs: 10
max_batches: 4000
lambda: 0.7
cont_target: "v"
quantile_loss_weight: 0.1
model_use_feature_graph: true
use_temporal_stage1: true
```
---
## 8. Sampling & Export

Defined in:
- `example/sample.py`
- `example/export_samples.py`

**Export steps:**
- Reverse diffusion with conditional sampling
- Reverse-normalize continuous values
- Clamp to the observed min/max
- Restore discrete tokens from the vocab
- Write to CSV
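A minimal sketch of the export-side post-processing; the z-score form of the normalization and the statistics layout are assumptions here, not confirmed by the source:

```
import torch

def denormalize(x, mean, std, vmin, vmax):
    # Invert the (assumed) per-feature z-score normalization, then clamp
    # to the min/max observed in the training data.
    x = x * std + mean
    return torch.clamp(x, min=vmin, max=vmax)
```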
---
## 9. Evaluation Metrics

Implemented in `example/evaluate_generated.py`.

### Continuous Metrics
- **KS statistic** (per-feature distribution similarity)
- **Quantile errors** (q05/q25/q50/q75/q95)
- **Lag-1 correlation diff** (temporal structure)

### Discrete Metrics
- **JSD** over token frequency distributions
- **Invalid token counts**

### Summary Metrics

Averages are auto-logged to `example/results/metrics_history.csv` via `example/summary_metrics.py`.
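A minimal sketch of the two headline metrics, assuming flattened per-feature value arrays for KS and token frequency vectors for JSD (the real implementation lives in `evaluate_generated.py`):

```
import numpy as np

def ks_stat(real, gen):
    # Two-sample KS: largest gap between the two empirical CDFs.
    grid = np.sort(np.concatenate([real, gen]))
    cdf_r = np.searchsorted(np.sort(real), grid, side="right") / len(real)
    cdf_g = np.searchsorted(np.sort(gen), grid, side="right") / len(gen)
    return float(np.max(np.abs(cdf_r - cdf_g)))

def jsd(p, q, eps=1e-12):
    # Jensen-Shannon divergence between two token frequency distributions.
    p, q = p / p.sum(), q / q.sum()
    m = 0.5 * (p + q)
    kl = lambda a, b: float(np.sum(a * np.log((a + eps) / (b + eps))))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)
```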
---
## 10. Automation

### One-click pipeline

```
python example/run_all.py --device cuda
```

### Metrics logging

Each run appends a row:

```
timestamp,avg_ks,avg_jsd,avg_lag1_diff
```
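The history file is plain CSV, so progress across runs can be inspected with the standard library (sketch):

```
import csv
from pathlib import Path

with Path("example/results/metrics_history.csv").open(encoding="utf-8") as f:
    for row in csv.DictReader(f):
        print(row["timestamp"], row["avg_ks"], row["avg_jsd"], row["avg_lag1_diff"])
```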
---
## 11. Key Engineering Decisions

### 11.1 Mixed-Type Diffusion
Continuous and discrete features are handled by separate branches to respect their data types.

### 11.2 Structure Prior
A learnable feature graph encodes implicit feature dependencies.

### 11.3 v-prediction
Chosen to stabilize training and improve convergence of the diffusion model.

### 11.4 Distribution Alignment
A quantile loss directly targets the KS metric.
---
## 12. Known Issues / Current Limitations

- **KS remains high** in many experiments, meaning continuous distributions are still misaligned.
- **Lag-1 correlation may degrade** when the quantile loss is too strong.
- **Loss spikes** occur when the quantile loss is unstable (mitigated with warmup + clip + Huber).
---
## 13. Suggested Next Steps (Research Roadmap)

1) **SNR-weighted loss** (improve stability across timesteps)
2) **Two-stage training** (distribution first, temporal consistency second)
3) **Upgrade discrete diffusion** (D3PM-style transitions)
4) **Structured conditioning** (state/phase conditioning)
5) **Graph-based priors** (explicit feature/plant dependency graphs)
---
## 14. Code Map (Key Files)

**Core model**
- `example/hybrid_diffusion.py`

**Training**
- `example/train.py`

**Sampling & export**
- `example/sample.py`
- `example/export_samples.py`

**Pipeline**
- `example/run_all.py`

**Evaluation**
- `example/evaluate_generated.py`
- `example/summary_metrics.py`

**Configs**
- `example/config.json`
---
## 15. Summary

This project implements a **hybrid diffusion model for ICS traffic features**, combining continuous Gaussian diffusion with discrete mask diffusion, enhanced with a **learnable feature-graph prior** and an optional two-stage temporal backbone. The system includes a full pipeline for preparation, training, sampling, exporting, and evaluation. Key research challenges remain in **distribution alignment (KS)** and the **joint optimization of distribution fidelity vs. temporal consistency**, motivating future improvements such as an SNR-weighted loss, staged training, and stronger structural priors.