Compare commits

9 commits by MZ YANG

SHA1        Message                 Date
dd4c1e171f  update new structure    2026-01-26 19:00:16 +08:00
f8edee9510  update                  2026-01-26 18:36:16 +08:00
cb610281ce  update new structure    2026-01-26 18:27:41 +08:00
bc838d7cd7  update ks               2026-01-25 18:13:37 +08:00
b3c45010a4  update                  2026-01-25 18:10:08 +08:00
2a1a9a05c6  update ks               2026-01-25 18:00:28 +08:00
cc10125fbf  update ks               2026-01-25 17:55:28 +08:00
5ef1e465f9  update                  2026-01-25 17:50:34 +08:00
5aba14d511  update ks               2026-01-25 17:39:28 +08:00
14 changed files with 2329 additions and 1673 deletions


@@ -72,3 +72,5 @@ python example/run_pipeline.py --device auto
 - The script only samples the first 5000 rows to stay fast.
 - `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
 - `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU.
+- Optional feature-graph mixer (`model_use_feature_graph`) adds a learnable relation prior across feature channels.
+- Optional two-stage temporal backbone (`use_temporal_stage1`) learns a GRU-based sequence trend; diffusion models the residual.


@@ -32,13 +32,24 @@
   "model_ff_mult": 2,
   "model_pos_dim": 64,
   "model_use_pos_embed": true,
+  "model_use_feature_graph": true,
+  "feature_graph_scale": 0.1,
+  "feature_graph_dropout": 0.0,
+  "use_temporal_stage1": true,
+  "temporal_hidden_dim": 256,
+  "temporal_num_layers": 1,
+  "temporal_dropout": 0.0,
+  "temporal_loss_weight": 1.0,
   "disc_mask_scale": 0.9,
   "cont_loss_weighting": "inv_std",
   "cont_loss_eps": 1e-6,
-  "cont_target": "x0",
+  "cont_target": "v",
   "cont_clamp_x0": 5.0,
   "quantile_loss_weight": 0.1,
   "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
+  "quantile_loss_warmup_steps": 200,
+  "quantile_loss_clip": 6.0,
+  "quantile_loss_huber_delta": 1.0,
   "shuffle_buffer": 256,
   "sample_batch_size": 8,
   "sample_seq_len": 128


@@ -13,7 +13,7 @@ import torch
 import torch.nn.functional as F
 from data_utils import load_split
-from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
+from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
 from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
@@ -140,6 +140,10 @@ def main():
         raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob)
     cont_target = str(cfg.get("cont_target", "eps"))
     cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
+    use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
+    temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
+    temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
+    temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
     model = HybridDiffusionModel(
         cont_dim=len(cont_cols),
@@ -151,6 +155,9 @@
         ff_mult=int(cfg.get("model_ff_mult", 2)),
         pos_dim=int(cfg.get("model_pos_dim", 64)),
         use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
+        use_feature_graph=bool(cfg.get("model_use_feature_graph", False)),
+        feature_graph_scale=float(cfg.get("feature_graph_scale", 0.1)),
+        feature_graph_dropout=float(cfg.get("feature_graph_dropout", 0.0)),
         cond_vocab_size=cond_vocab_size if use_condition else 0,
         cond_dim=int(cfg.get("cond_dim", 32)),
         use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
@@ -163,6 +170,20 @@ def main():
     model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
     model.eval()
+    temporal_model = None
+    if use_temporal_stage1:
+        temporal_model = TemporalGRUGenerator(
+            input_dim=len(cont_cols),
+            hidden_dim=temporal_hidden_dim,
+            num_layers=temporal_num_layers,
+            dropout=temporal_dropout,
+        ).to(device)
+        temporal_path = Path(args.model_path).with_name("temporal.pt")
+        if not temporal_path.exists():
+            raise SystemExit(f"missing temporal model file: {temporal_path}")
+        temporal_model.load_state_dict(torch.load(temporal_path, map_location=device, weights_only=True))
+        temporal_model.eval()
     betas = cosine_beta_schedule(args.timesteps).to(device)
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, dim=0)
@@ -189,6 +210,10 @@ def main():
         cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
         cond = cond_id
+    trend = None
+    if temporal_model is not None:
+        trend = temporal_model.generate(args.batch_size, args.seq_len, device)
     for t in reversed(range(args.timesteps)):
         t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
         eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
@@ -201,6 +226,10 @@ def main():
             if cont_clamp_x0 > 0:
                 x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
             eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
+        elif cont_target == "v":
+            v_pred = eps_pred
+            x0_pred = torch.sqrt(a_bar_t) * x_cont - torch.sqrt(1.0 - a_bar_t) * v_pred
+            eps_pred = torch.sqrt(1.0 - a_bar_t) * x_cont + torch.sqrt(a_bar_t) * v_pred
         coef1 = 1.0 / torch.sqrt(a_t)
         coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
         mean_x = coef1 * (x_cont - coef2 * eps_pred)
@@ -225,6 +254,8 @@ def main():
             )
             x_disc[:, :, i][mask] = sampled[mask]
+    if trend is not None:
+        x_cont = x_cont + trend
     # move to CPU for export
     x_cont = x_cont.cpu()
     x_disc = x_disc.cpu()


@@ -66,6 +66,64 @@ class SinusoidalTimeEmbedding(nn.Module):
         return emb
+
+
+class FeatureGraphMixer(nn.Module):
+    """Learnable feature relation mixer (dataset-agnostic)."""
+
+    def __init__(self, dim: int, scale: float = 0.1, dropout: float = 0.0):
+        super().__init__()
+        self.scale = scale
+        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+        self.A = nn.Parameter(torch.zeros(dim, dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (B, T, D)
+        # Symmetric relation to stabilize
+        A = (self.A + self.A.t()) * 0.5
+        mixed = torch.matmul(x, A) * self.scale
+        mixed = self.dropout(mixed)
+        return x + mixed
+
+
+class TemporalGRUGenerator(nn.Module):
+    """Stage-1 temporal generator (autoregressive GRU) for sequence backbone."""
+
+    def __init__(self, input_dim: int, hidden_dim: int = 256, num_layers: int = 1, dropout: float = 0.0):
+        super().__init__()
+        self.start_token = nn.Parameter(torch.zeros(input_dim))
+        self.gru = nn.GRU(
+            input_dim,
+            hidden_dim,
+            num_layers=num_layers,
+            dropout=dropout if num_layers > 1 else 0.0,
+            batch_first=True,
+        )
+        self.out = nn.Linear(hidden_dim, input_dim)
+
+    def forward_teacher(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Teacher-forced next-step prediction. Returns trend sequence and preds."""
+        if x.size(1) < 2:
+            raise ValueError("sequence length must be >= 2 for teacher forcing")
+        inp = x[:, :-1, :]
+        out, _ = self.gru(inp)
+        pred_next = self.out(out)
+        trend = torch.zeros_like(x)
+        trend[:, 0, :] = x[:, 0, :]
+        trend[:, 1:, :] = pred_next
+        return trend, pred_next
+
+    def generate(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
+        """Autoregressively generate a backbone sequence."""
+        h = None
+        prev = self.start_token.unsqueeze(0).expand(batch_size, -1).to(device)
+        outputs = []
+        for _ in range(seq_len):
+            out, h = self.gru(prev.unsqueeze(1), h)
+            nxt = self.out(out.squeeze(1))
+            outputs.append(nxt.unsqueeze(1))
+            prev = nxt
+        return torch.cat(outputs, dim=1)
+
+
 class HybridDiffusionModel(nn.Module):
     def __init__(
         self,
@@ -78,6 +136,9 @@ class HybridDiffusionModel(nn.Module):
         ff_mult: int = 2,
         pos_dim: int = 64,
         use_pos_embed: bool = True,
+        use_feature_graph: bool = False,
+        feature_graph_scale: float = 0.1,
+        feature_graph_dropout: float = 0.0,
         cond_vocab_size: int = 0,
         cond_dim: int = 32,
         use_tanh_eps: bool = False,
@@ -92,6 +153,7 @@ class HybridDiffusionModel(nn.Module):
         self.eps_scale = eps_scale
         self.pos_dim = pos_dim
         self.use_pos_embed = use_pos_embed
+        self.use_feature_graph = use_feature_graph
         self.cond_vocab_size = cond_vocab_size
         self.cond_dim = cond_dim
@@ -106,8 +168,17 @@
         disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)
         self.cont_proj = nn.Linear(cont_dim, cont_dim)
+        self.feature_dim = cont_dim + disc_embed_dim
+        if use_feature_graph:
+            self.feature_graph = FeatureGraphMixer(
+                self.feature_dim,
+                scale=feature_graph_scale,
+                dropout=feature_graph_dropout,
+            )
+        else:
+            self.feature_graph = None
         pos_dim = pos_dim if use_pos_embed else 0
-        in_dim = cont_dim + disc_embed_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
+        in_dim = self.feature_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
         self.in_proj = nn.Linear(in_dim, hidden_dim)
         self.backbone = nn.GRU(
             hidden_dim,
@@ -149,7 +220,10 @@
             cond_feat = self.cond_embed(cond).unsqueeze(1).expand(-1, x_cont.size(1), -1)
         cont_feat = self.cont_proj(x_cont)
-        parts = [cont_feat, disc_feat, time_emb]
+        feat = torch.cat([cont_feat, disc_feat], dim=-1)
+        if self.feature_graph is not None:
+            feat = self.feature_graph(feat)
+        parts = [feat, time_emb]
         if pos_emb is not None:
             parts.append(pos_emb.unsqueeze(0).expand(x_cont.size(0), -1, -1))
         if cond_feat is not None:


@@ -34,6 +34,7 @@ def main():
     loss = []
     loss_cont = []
    loss_disc = []
+    loss_quant = []
     with log_path.open("r", encoding="utf-8", newline="") as f:
         reader = csv.DictReader(f)
@@ -42,6 +43,8 @@
             loss.append(float(row["loss"]))
             loss_cont.append(float(row["loss_cont"]))
             loss_disc.append(float(row["loss_disc"]))
+            if "loss_quantile" in row:
+                loss_quant.append(float(row["loss_quantile"]))
     if not steps:
         raise SystemExit("no rows in log file: %s" % log_path)
@@ -50,6 +53,8 @@
     plt.plot(steps, loss, label="total")
     plt.plot(steps, loss_cont, label="continuous")
     plt.plot(steps, loss_disc, label="discrete")
+    if loss_quant:
+        plt.plot(steps, loss_quant, label="quantile")
     plt.xlabel("step")
     plt.ylabel("loss")
     plt.title("Training Loss")


@@ -32,11 +32,14 @@
   "model_ff_mult": 2,
   "model_pos_dim": 64,
   "model_use_pos_embed": true,
+  "model_use_feature_graph": true,
+  "feature_graph_scale": 0.1,
+  "feature_graph_dropout": 0.0,
   "disc_mask_scale": 0.9,
   "shuffle_buffer": 256,
   "cont_loss_weighting": "inv_std",
   "cont_loss_eps": 1e-06,
-  "cont_target": "x0",
+  "cont_target": "v",
   "cont_clamp_x0": 5.0,
   "quantile_loss_weight": 0.1,
   "quantile_points": [
@@ -46,6 +49,9 @@
     0.75,
     0.95
   ],
+  "quantile_loss_warmup_steps": 200,
+  "quantile_loss_clip": 6.0,
+  "quantile_loss_huber_delta": 1.0,
   "sample_batch_size": 8,
   "sample_seq_len": 128
 }

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -1,61 +1,61 @@
-epoch,step,loss,loss_cont,loss_disc
-0,0,9801.515625,14001.333008,1.724242
-0,10,8648.471680,12354.386719,1.050416
-0,20,6285.875000,8979.247070,0.775979
-0,30,6297.153809,8995.444336,0.598639
-0,40,7421.842285,10602.228516,0.402729
-0,50,8768.145508,12525.551758,0.339361
-1,0,17362.210938,24802.792969,0.521414
-1,10,9998.125000,14282.763672,0.346019
-1,20,6025.115234,8606.978516,0.211357
-1,30,5529.321777,7898.669922,0.321453
-1,40,4561.373047,6515.895508,0.306940
-1,50,4133.944336,5905.306641,0.271862
-2,0,18390.314453,26271.496094,0.559547
-2,10,5629.449707,8041.824219,0.315514
-2,20,5088.290527,7268.675293,0.209568
-2,30,4109.000000,5869.645020,0.316234
-2,40,4135.112305,5906.964355,0.288305
-2,50,3574.877686,5106.638672,0.300608
-3,0,16117.513672,23024.658203,0.541584
-3,10,7299.515137,10427.646484,0.307387
-3,20,3928.072998,5611.221680,0.202967
-3,30,4386.781250,6266.497559,0.325895
-3,40,3627.836182,5182.310059,0.271281
-3,50,3513.194580,5018.536133,0.291253
-4,0,14937.586914,21339.070312,0.552027
-4,10,6087.895508,8696.762695,0.334634
-4,20,3961.117676,5658.443359,0.210930
-4,30,3405.418457,4864.548828,0.317895
-4,40,3483.671631,4976.360840,0.296246
-4,50,2833.118164,4047.002441,0.297401
-5,0,12412.599609,17731.978516,0.492310
-5,10,4952.285156,7074.484375,0.294987
-5,20,4023.841309,5748.041504,0.219470
-5,30,3416.583740,4880.517090,0.310916
-5,40,3283.848389,4690.912598,0.274809
-5,50,2953.851807,4219.480469,0.286249
-6,0,9772.470703,13960.334961,0.525315
-6,10,4856.467773,6937.604980,0.294435
-6,20,3487.783203,4982.249023,0.249389
-6,30,2907.010498,4152.563965,0.302258
-6,40,2978.796875,4255.132324,0.272516
-6,50,2954.490723,4220.402832,0.260484
-7,0,5634.914062,8049.578125,0.477380
-7,10,4834.394531,6906.059570,0.314162
-7,20,2799.942871,3999.613770,0.267521
-7,30,2899.989990,4142.533203,0.294986
-7,40,2961.559570,4230.455078,0.337903
-7,50,3053.434814,4361.746094,0.263959
-8,0,5015.993652,7165.385742,0.495864
-8,10,3965.379639,5664.615234,0.316581
-8,20,3669.429688,5241.729980,0.204997
-8,30,2815.938232,4022.476562,0.271451
-8,40,2967.452881,4238.926758,0.263370
-8,50,2930.122314,4185.593262,0.262396
-9,0,4364.022461,6233.995605,0.496020
-9,10,4222.906250,6032.508301,0.319415
-9,20,3070.776367,4386.530762,0.233229
-9,30,2839.424805,4056.029785,0.265954
-9,40,2770.363770,3957.375000,0.264960
-9,50,2557.437256,3653.188721,0.265549
+epoch,step,loss,loss_cont,loss_disc,loss_quantile
+0,0,11527.750000,16467.474609,1.724242,0.113109
+0,10,8955.048828,12792.469727,1.065955,0.142707
+0,20,123049.507812,175784.562500,1.048191,0.181546
+0,30,30947.671875,44210.511719,1.034107,0.224363
+0,40,11166.339844,15951.691406,0.512211,0.163978
+0,50,14919.276367,21313.089844,0.369016,0.101041
+1,0,1113425.000000,1590606.875000,0.667305,0.236102
+1,10,38804.527344,55433.980469,2.453400,0.238323
+1,20,138075.984375,197251.281250,0.288013,0.157941
+1,30,56904.078125,81291.343750,0.429261,0.174654
+1,40,10662.019531,15231.303711,0.343701,0.079876
+1,50,9890.013672,14128.458008,0.292848,0.088391
+2,0,912398.500000,1303426.125000,0.795286,0.256503
+2,10,12546.185547,17922.542969,1.328752,0.110444
+2,20,109676.710938,156680.906250,0.245523,0.132549
+2,30,49427.507812,70610.578125,0.327188,0.088591
+2,40,27778.673828,39683.660156,0.345683,0.102361
+2,50,10311.509766,14730.566406,0.350199,0.091674
+3,0,1040308.062500,1486154.000000,0.807157,0.279805
+3,10,64799.246094,92570.117188,0.485949,0.198213
+3,20,336018.000000,480025.531250,0.410466,0.118048
+3,30,94216.312500,134594.562500,0.355044,0.114209
+3,40,19988.919922,28555.457031,0.298066,0.094291
+3,50,9181.969727,13116.940430,0.326489,0.137392
+4,0,741176.187500,1058822.750000,0.850633,0.233118
+4,10,39252.617188,56074.410156,1.690607,0.227370
+4,20,108992.304688,155703.140625,0.258920,0.285944
+4,30,37115.253906,53021.632812,0.337674,0.131537
+4,40,19358.708984,27655.123047,0.349937,0.169446
+4,50,11434.291992,16334.540039,0.351904,0.089116
+5,0,845658.312500,1208083.000000,0.922483,0.285983
+5,10,130569.406250,186527.515625,0.405198,0.227639
+5,20,245780.390625,351114.687500,0.301236,0.147030
+5,30,42017.671875,60025.066406,0.372036,0.117895
+5,40,11496.740234,16423.779297,0.286911,0.085701
+5,50,8891.728516,12702.317383,0.322913,0.099181
+6,0,617909.687500,882727.812500,0.834663,0.205251
+6,10,18171.734375,25959.302734,0.677465,0.190570
+6,20,423716.187500,605308.687500,0.464189,0.156125
+6,30,48133.507812,68761.914062,0.478786,0.221465
+6,40,20350.281250,29071.666016,0.343182,0.118739
+6,50,11372.219727,16245.889648,0.289177,0.101218
+7,0,534133.500000,763047.500000,0.732465,0.239962
+7,10,19025.574219,27178.703125,1.551094,0.183023
+7,20,58042.757812,82918.078125,0.298854,0.146532
+7,30,11810.656250,16872.201172,0.353035,0.092791
+7,40,9733.656250,13905.077148,0.312043,0.089972
+7,50,9069.159180,12955.803711,0.280795,0.129749
+8,0,650366.937500,929095.250000,0.787194,0.263791
+8,10,13121.248047,18744.031250,1.376233,0.141053
+8,20,19326.507812,27609.167969,0.243424,0.183426
+8,30,12904.376953,18434.667969,0.334577,0.092547
+8,40,9682.833984,13832.478516,0.296229,0.109518
+8,50,8866.144531,12665.773438,0.312481,0.099564
+9,0,265689.562500,379556.156250,0.738840,0.186122
+9,10,9944.607422,14206.222656,0.806142,0.096763
+9,20,22091.175781,31558.679688,0.285403,0.131950
+9,30,19346.830078,27638.111328,0.460825,0.131496
+9,40,10279.702148,14685.139648,0.314872,0.095758
+9,50,10420.340820,14886.050781,0.314008,0.112904


@@ -75,6 +75,7 @@ def main():
         run([sys.executable, str(base_dir / "evaluate_generated.py"), "--reference", str(ref)])
     else:
         run([sys.executable, str(base_dir / "evaluate_generated.py")])
+    run([sys.executable, str(base_dir / "summary_metrics.py")])
 if __name__ == "__main__":


@@ -11,6 +11,7 @@ import torch.nn.functional as F
 from data_utils import load_split
 from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
+from hybrid_diffusion import TemporalGRUGenerator
 from platform_utils import resolve_device, safe_path, ensure_dir
 BASE_DIR = Path(__file__).resolve().parent
@@ -56,6 +57,13 @@ def main():
     model_ff_mult = int(cfg.get("model_ff_mult", 2))
     model_pos_dim = int(cfg.get("model_pos_dim", 64))
     model_use_pos = bool(cfg.get("model_use_pos_embed", True))
+    model_use_feature_graph = bool(cfg.get("model_use_feature_graph", False))
+    feature_graph_scale = float(cfg.get("feature_graph_scale", 0.1))
+    feature_graph_dropout = float(cfg.get("feature_graph_dropout", 0.0))
+    use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
+    temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
+    temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
+    temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
     split = load_split(str(SPLIT_PATH))
     time_col = split.get("time_column", "time")
@@ -83,6 +91,9 @@
         ff_mult=model_ff_mult,
         pos_dim=model_pos_dim,
         use_pos_embed=model_use_pos,
+        use_feature_graph=model_use_feature_graph,
+        feature_graph_scale=feature_graph_scale,
+        feature_graph_dropout=feature_graph_dropout,
         cond_vocab_size=cond_vocab_size,
         cond_dim=cond_dim,
         use_tanh_eps=use_tanh_eps,
@@ -92,6 +103,20 @@ def main():
     model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
     model.eval()
+    temporal_model = None
+    if use_temporal_stage1:
+        temporal_model = TemporalGRUGenerator(
+            input_dim=len(cont_cols),
+            hidden_dim=temporal_hidden_dim,
+            num_layers=temporal_num_layers,
+            dropout=temporal_dropout,
+        ).to(DEVICE)
+        temporal_path = BASE_DIR / "results" / "temporal.pt"
+        if not temporal_path.exists():
+            raise SystemExit(f"missing temporal model file: {temporal_path}")
+        temporal_model.load_state_dict(torch.load(str(temporal_path), map_location=DEVICE, weights_only=True))
+        temporal_model.eval()
     betas = cosine_beta_schedule(timesteps).to(DEVICE)
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, dim=0)
@@ -110,19 +135,26 @@ def main():
             raise SystemExit("use_condition enabled but no files matched data_glob")
         cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long)
+    trend = None
+    if temporal_model is not None:
+        trend = temporal_model.generate(batch_size, seq_len, DEVICE)
     for t in reversed(range(timesteps)):
         t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
         eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
+        # Continuous reverse step (DDPM): x_{t-1} mean
+        a_t = alphas[t]
+        a_bar_t = alphas_cumprod[t]
         if cont_target == "x0":
             x0_pred = eps_pred
             if cont_clamp_x0 > 0:
                 x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
             eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
-        # Continuous reverse step (DDPM): x_{t-1} mean
-        a_t = alphas[t]
-        a_bar_t = alphas_cumprod[t]
+        elif cont_target == "v":
+            v_pred = eps_pred
+            x0_pred = torch.sqrt(a_bar_t) * x_cont - torch.sqrt(1.0 - a_bar_t) * v_pred
+            eps_pred = torch.sqrt(1.0 - a_bar_t) * x_cont + torch.sqrt(a_bar_t) * v_pred
         coef1 = 1.0 / torch.sqrt(a_t)
         coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
         mean = coef1 * (x_cont - coef2 * eps_pred)
@@ -146,6 +178,8 @@ def main():
             sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(BATCH_SIZE, SEQ_LEN)
             x_disc[:, :, i][mask] = sampled[mask]
+    if trend is not None:
+        x_cont = x_cont + trend
     print("sampled_cont_shape", tuple(x_cont.shape))
     print("sampled_disc_shape", tuple(x_disc.shape))


@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+"""Print average metrics from eval.json for quick tracking."""
+import json
+from datetime import datetime
+from pathlib import Path
+
+
+def mean(values):
+    return sum(values) / len(values) if values else None
+
+
+def main():
+    base_dir = Path(__file__).resolve().parent
+    eval_path = base_dir / "results" / "eval.json"
+    if not eval_path.exists():
+        raise SystemExit(f"missing eval.json: {eval_path}")
+    obj = json.loads(eval_path.read_text(encoding="utf-8"))
+    ks = list(obj.get("continuous_ks", {}).values())
+    jsd = list(obj.get("discrete_jsd", {}).values())
+    lag = list(obj.get("continuous_lag1_diff", {}).values())
+    avg_ks = mean(ks)
+    avg_jsd = mean(jsd)
+    avg_lag1 = mean(lag)
+    print("avg_ks", avg_ks)
+    print("avg_jsd", avg_jsd)
+    print("avg_lag1_diff", avg_lag1)
+    history_path = base_dir / "results" / "metrics_history.csv"
+    if not history_path.exists():
+        history_path.write_text("timestamp,avg_ks,avg_jsd,avg_lag1_diff\n", encoding="utf-8")
+    with history_path.open("a", encoding="utf-8") as f:
+        f.write(f"{datetime.utcnow().isoformat()},{avg_ks},{avg_jsd},{avg_lag1}\n")
+
+
+if __name__ == "__main__":
+    main()


@@ -14,6 +14,7 @@ import torch.nn.functional as F
 from data_utils import load_split, windowed_batches
 from hybrid_diffusion import (
     HybridDiffusionModel,
+    TemporalGRUGenerator,
     cosine_beta_schedule,
     q_sample_continuous,
     q_sample_discrete,
@@ -58,14 +59,25 @@ DEFAULTS = {
     "model_ff_mult": 2,
     "model_pos_dim": 64,
     "model_use_pos_embed": True,
+    "model_use_feature_graph": True,
+    "feature_graph_scale": 0.1,
+    "feature_graph_dropout": 0.0,
+    "use_temporal_stage1": True,
+    "temporal_hidden_dim": 256,
+    "temporal_num_layers": 1,
+    "temporal_dropout": 0.0,
+    "temporal_loss_weight": 1.0,
     "disc_mask_scale": 0.9,
     "shuffle_buffer": 256,
     "cont_loss_weighting": "none",  # none | inv_std
     "cont_loss_eps": 1e-6,
-    "cont_target": "eps",  # eps | x0
+    "cont_target": "eps",  # eps | x0 | v
     "cont_clamp_x0": 0.0,
     "quantile_loss_weight": 0.0,
     "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
+    "quantile_loss_warmup_steps": 200,
+    "quantile_loss_clip": 6.0,
+    "quantile_loss_huber_delta": 1.0,
 }
@@ -190,12 +202,27 @@
         ff_mult=int(config.get("model_ff_mult", 2)),
         pos_dim=int(config.get("model_pos_dim", 64)),
         use_pos_embed=bool(config.get("model_use_pos_embed", True)),
+        use_feature_graph=bool(config.get("model_use_feature_graph", False)),
+        feature_graph_scale=float(config.get("feature_graph_scale", 0.1)),
+        feature_graph_dropout=float(config.get("feature_graph_dropout", 0.0)),
         cond_vocab_size=cond_vocab_size,
         cond_dim=int(config.get("cond_dim", 32)),
         use_tanh_eps=bool(config.get("use_tanh_eps", False)),
         eps_scale=float(config.get("eps_scale", 1.0)),
     ).to(device)
+    temporal_model = None
+    if bool(config.get("use_temporal_stage1", False)):
+        temporal_model = TemporalGRUGenerator(
+            input_dim=len(cont_cols),
+            hidden_dim=int(config.get("temporal_hidden_dim", 256)),
+            num_layers=int(config.get("temporal_num_layers", 1)),
+            dropout=float(config.get("temporal_dropout", 0.0)),
+        ).to(device)
     opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
+    if temporal_model is not None:
+        opt_temporal = torch.optim.Adam(temporal_model.parameters(), lr=float(config["lr"]))
+    else:
+        opt_temporal = None
     ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None
     betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
@@ -205,8 +232,12 @@
     os.makedirs(config["out_dir"], exist_ok=True)
     out_dir = safe_path(config["out_dir"])
     log_path = os.path.join(out_dir, "train_log.csv")
+    use_quantile = float(config.get("quantile_loss_weight", 0.0)) > 0
     with open(log_path, "w", encoding="utf-8") as f:
-        f.write("epoch,step,loss,loss_cont,loss_disc\n")
+        if use_quantile:
+            f.write("epoch,step,loss,loss_cont,loss_disc,loss_quantile\n")
+        else:
+            f.write("epoch,step,loss,loss_cont,loss_disc\n")
     with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f:
         json.dump(config, f, indent=2)
@@ -237,10 +268,20 @@
             x_cont = x_cont.to(device)
             x_disc = x_disc.to(device)
+            temporal_loss = None
+            x_cont_resid = x_cont
+            trend = None
+            if temporal_model is not None:
+                trend, pred_next = temporal_model.forward_teacher(x_cont)
+                temporal_loss = F.mse_loss(pred_next, x_cont[:, 1:, :])
+                temporal_loss = temporal_loss * float(config.get("temporal_loss_weight", 1.0))
+                trend = trend.detach()
+                x_cont_resid = x_cont - trend
             bsz = x_cont.size(0)
             t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
-            x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
+            x_cont_t, noise = q_sample_continuous(x_cont_resid, t, alphas_cumprod)
             mask_tokens = torch.tensor(vocab_sizes, device=device)
             x_disc_t, mask = q_sample_discrete(
@@ -255,10 +296,14 @@
             cont_target = str(config.get("cont_target", "eps"))
             if cont_target == "x0":
-                x0_target = x_cont
+                x0_target = x_cont_resid
                 if float(config.get("cont_clamp_x0", 0.0)) > 0:
                     x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"]))
                 loss_base = (eps_pred - x0_target) ** 2
+            elif cont_target == "v":
+                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
+                v_target = torch.sqrt(a_bar_t) * noise - torch.sqrt(1.0 - a_bar_t) * x_cont_resid
+                loss_base = (eps_pred - v_target) ** 2
             else:
                 loss_base = (eps_pred - noise) ** 2
@@ -286,36 +331,74 @@
             loss = lam * loss_cont + (1 - lam) * loss_disc
             q_weight = float(config.get("quantile_loss_weight", 0.0))
+            quantile_loss = 0.0
             if q_weight > 0:
+                warmup = int(config.get("quantile_loss_warmup_steps", 0))
+                if warmup > 0:
+                    q_weight = q_weight * min(1.0, (total_step + 1) / float(warmup))
                 q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
                 q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
-                # Use normalized space for stable quantiles.
-                x_real = x_cont
+                # Use normalized space for stable quantiles on x0.
+                x_real = x_cont_resid
+                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
                 if cont_target == "x0":
                     x_gen = eps_pred
+                elif cont_target == "v":
+                    v_pred = eps_pred
+                    x_gen = torch.sqrt(a_bar_t) * x_cont_t - torch.sqrt(1.0 - a_bar_t) * v_pred
                 else:
-                    x_gen = x_cont - noise
+                    # eps prediction
+                    x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
+                q_clip = float(config.get("quantile_loss_clip", 0.0))
+                if q_clip > 0:
+                    x_real = torch.clamp(x_real, -q_clip, q_clip)
+                    x_gen = torch.clamp(x_gen, -q_clip, q_clip)
                 x_real = x_real.view(-1, x_real.size(-1))
                 x_gen = x_gen.view(-1, x_gen.size(-1))
                 q_real = torch.quantile(x_real, q_tensor, dim=0)
                 q_gen = torch.quantile(x_gen, q_tensor, dim=0)
-                quantile_loss = torch.mean(torch.abs(q_gen - q_real))
+                q_delta = float(config.get("quantile_loss_huber_delta", 0.0))
+                q_diff = q_gen - q_real
+                if q_delta > 0:
+                    quantile_loss = torch.nn.functional.smooth_l1_loss(q_gen, q_real, beta=q_delta)
+                else:
+                    quantile_loss = torch.mean(torch.abs(q_diff))
                 loss = loss + q_weight * quantile_loss
             opt.zero_grad()
             loss.backward()
             if float(config.get("grad_clip", 0.0)) > 0:
                 torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
             opt.step()
+            if opt_temporal is not None:
+                opt_temporal.zero_grad()
+                temporal_loss.backward()
+                if float(config.get("grad_clip", 0.0)) > 0:
+                    torch.nn.utils.clip_grad_norm_(temporal_model.parameters(), float(config["grad_clip"]))
+                opt_temporal.step()
             if ema is not None:
                 ema.update(model)
             if step % int(config["log_every"]) == 0:
                 print("epoch", epoch, "step", step, "loss", float(loss))
                 with open(log_path, "a", encoding="utf-8") as f:
-                    f.write(
-                        "%d,%d,%.6f,%.6f,%.6f\n"
-                        % (epoch, step, float(loss), float(loss_cont), float(loss_disc))
-                    )
+                    if use_quantile:
+                        f.write(
+                            "%d,%d,%.6f,%.6f,%.6f,%.6f\n"
+                            % (
+                                epoch,
+                                step,
+                                float(loss),
+                                float(loss_cont),
+                                float(loss_disc),
+                                float(quantile_loss),
+                            )
+                        )
+                    else:
+                        f.write(
+                            "%d,%d,%.6f,%.6f,%.6f\n"
+                            % (epoch, step, float(loss), float(loss_cont), float(loss_disc))
+                        )
             total_step += 1
             if total_step % int(config["ckpt_every"]) == 0:
@@ -327,11 +410,15 @@
                 }
                 if ema is not None:
                     ckpt["ema"] = ema.state_dict()
+                if temporal_model is not None:
+                    ckpt["temporal"] = temporal_model.state_dict()
                 torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))
     torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
     if ema is not None:
         torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))
+    if temporal_model is not None:
+        torch.save(temporal_model.state_dict(), os.path.join(out_dir, "temporal.pt"))
 if __name__ == "__main__":

report.md (new file, 365 lines)

@@ -0,0 +1,365 @@
# Hybrid Diffusion for ICS Traffic (HAI 21.03) — Project Report
## 1. Project Goal

Build a **hybrid diffusion-based generator** for industrial control system (ICS) traffic features, targeting **mixed continuous + discrete** feature sequences. The output is **feature-level sequences**, not raw packets. The generator should preserve:

- **Distributional fidelity** (continuous value ranges and discrete frequencies)
- **Temporal consistency** (time correlation and sequence structure)
- **Protocol/field consistency** (for discrete fields)

This project is aligned with the STOUTER idea of **structure-aware diffusion** for spatiotemporal data, but applied to **ICS feature sequences** rather than cellular traffic.
---
## 2. Data and Scope

**Dataset used in the current implementation:** HAI 21.03 (CSV feature traces)

**Data location (default in config):**

- `dataset/hai/hai-21.03/train*.csv.gz`

**Feature split (fixed schema):**

- Defined in `example/feature_split.json`
- **Continuous features:** sensor/process values
- **Discrete features:** binary/low-cardinality status/flag fields
- The `time` column is excluded from modeling
---
## 3. End-to-End Pipeline

**One-command pipeline:**

```
python example/run_all.py --device cuda
```

### Pipeline stages

1) **Prepare data** (`example/prepare_data.py`): statistics and vocabularies
2) **Train model** (`example/train.py`)
3) **Generate samples** (`example/export_samples.py`)
4) **Evaluate** (`example/evaluate_generated.py`)
5) **Summarize metrics** (`example/summary_metrics.py`)
---
## 4. Technical Architecture

### 4.1 Hybrid Diffusion Model (Core)

Defined in `example/hybrid_diffusion.py`.

**Key components:**

- **Continuous branch**: Gaussian diffusion (DDPM style)
- **Discrete branch**: mask diffusion for categorical tokens
- **Shared backbone**: GRU + residual MLP + LayerNorm
- **Embedding inputs**:
  - continuous projection
  - discrete embeddings per column
  - time embedding (sinusoidal)
  - positional embedding (sequence index)
  - optional condition embedding (`file_id`)

**Outputs:**

- Continuous head: predicts the configured target (`eps`, `x0`, or `v`)
- Discrete heads: predict logits for each discrete column
---
### 4.2 Feature Graph Mixer (Structure Prior)

Implemented in `example/hybrid_diffusion.py` as `FeatureGraphMixer`.

Purpose: inject a **learnable feature-dependency prior** without dataset-specific hardcoding.

**Mechanism:**

- Learns a dense feature relation matrix `A`
- Applies `x + x @ A`
- Symmetric stabilizing constraint: `(A + A^T)/2`
- Controlled by scale and dropout

**Config:**
```
"model_use_feature_graph": true,
"feature_graph_scale": 0.1,
"feature_graph_dropout": 0.0
```
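
A usage sketch, assuming `example/` is importable (the class itself is added in the `hybrid_diffusion.py` diff above). The mixer preserves the `(B, T, D)` shape and starts as the identity because `A` is initialized to zeros:

```
import torch

from hybrid_diffusion import FeatureGraphMixer

# (B, T, D): concatenated continuous + discrete-embedding features
x = torch.randn(4, 128, 32)
mixer = FeatureGraphMixer(dim=32, scale=0.1, dropout=0.0)

out = mixer(x)
print(out.shape)               # torch.Size([4, 128, 32])
print(torch.allclose(out, x))  # True at init: A == 0, so mixing is a no-op
```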
---
### 4.3 Two-Stage Temporal Backbone

Stage 1 uses a **GRU temporal generator** to model the sequence trend in normalized space; stage-2 diffusion then models the **residual** (`x - trend`). This decouples temporal consistency from distribution alignment, as sketched below.
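
A minimal sketch of the decomposition, using `TemporalGRUGenerator` from `example/hybrid_diffusion.py` (names follow the train/sample diffs above):

```
import torch

from hybrid_diffusion import TemporalGRUGenerator

gen = TemporalGRUGenerator(input_dim=8, hidden_dim=64)
x = torch.randn(2, 128, 8)  # normalized continuous features, shape (B, T, D)

# Training: stage 1 learns the trend; stage-2 diffusion sees only the residual.
trend, pred_next = gen.forward_teacher(x)
residual = x - trend.detach()  # this is what q_sample_continuous() noises

# Sampling: generate a trend autoregressively, then add the diffused residual back.
trend = gen.generate(batch_size=2, seq_len=128, device=torch.device("cpu"))
# x_cont = reverse_diffusion(...) + trend
```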
---
## 5. Diffusion Formulations

### 5.1 Continuous Diffusion
Forward process:
```
x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
```
Targets supported:

- **eps prediction** (standard DDPM)
- **x0 prediction** (direct reconstruction)
- **v prediction** (`v = sqrt(a_bar) * eps - sqrt(1 - a_bar) * x0`)
Current config default:
```
"cont_target": "v"
```
Sampling uses the configured target to reconstruct `eps`, then applies the standard DDPM reverse update.
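
Since all three targets are linear in `(x0, eps)`, they are interchangeable given `x_t` and `a_bar_t`; a sketch of the conversions used at sampling time, matching the formulas above:

```
import torch

def eps_from_target(pred, x_t, a_bar_t, target="v"):
    """Recover the noise estimate from whichever quantity the network predicts."""
    sqrt_ab = torch.sqrt(a_bar_t)
    sqrt_1m = torch.sqrt(1.0 - a_bar_t)
    if target == "eps":
        return pred
    if target == "x0":
        # x_t = sqrt(a_bar)*x0 + sqrt(1-a_bar)*eps, solved for eps
        return (x_t - sqrt_ab * pred) / sqrt_1m
    if target == "v":
        # substitute v = sqrt(a_bar)*eps - sqrt(1-a_bar)*x0 into the forward process
        return sqrt_1m * x_t + sqrt_ab * pred
    raise ValueError(f"unknown cont_target: {target}")
```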
---
### 5.2 Discrete Diffusion (Mask)

Forward process: replace tokens with `[MASK]` according to a cosine schedule:
```
p(t) = 0.5 * (1 - cos(pi * t / T))
```
Optional scale: `disc_mask_scale`.

Reverse process: cross-entropy on masked positions only; a masking sketch follows.
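
A sketch of the implied forward masking step (the mask token id and `disc_mask_scale` follow the config; the project's actual helper is `q_sample_discrete` in `example/hybrid_diffusion.py`, whose full signature is not shown in the diff):

```
import math
import torch

def mask_prob(t: int, T: int, scale: float = 0.9) -> float:
    # cosine schedule: p(t) = 0.5 * (1 - cos(pi * t / T)), optionally scaled
    return scale * 0.5 * (1.0 - math.cos(math.pi * t / T))

def mask_tokens(x_disc: torch.Tensor, t: int, T: int, mask_id: int, scale: float = 0.9):
    p = mask_prob(t, T, scale)
    mask = torch.rand(x_disc.shape, device=x_disc.device) < p
    x_t = x_disc.clone()
    x_t[mask] = mask_id  # replaced positions; cross-entropy is computed here only
    return x_t, mask
```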
---
## 6. Loss Design (Current)

Total loss:

```
L = λ * L_cont + (1 - λ) * L_disc + w_q * L_quantile
```
### 6.1 Continuous Loss
Depending on `cont_target`:
- eps target: MSE(eps_pred, eps)
- x0 target: MSE(x0_pred, x0)
- v target: MSE(v_pred, v_target)
Optional inverse-variance weighting:
```
cont_loss_weighting = "inv_std"
```
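
One plausible reading of this option (the exact weighting code in `example/train.py` is not shown in the diff above): scale each feature's squared error by the inverse of its standard deviation so that high-variance channels do not dominate the loss:

```
import torch

def inv_std_weighted_loss(loss_base, x0, eps=1e-6):
    # loss_base: per-element squared error, shape (B, T, D)
    std = x0.reshape(-1, x0.size(-1)).std(dim=0)  # per-feature std over batch and time
    w = 1.0 / (std + eps)                         # "inv_std" weighting, cont_loss_eps guard
    return (loss_base * w).mean()
```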
### 6.2 Discrete Loss

Cross-entropy on masked positions only.

### 6.3 Quantile Loss (Distribution Alignment)
Added to improve KS (distribution shape alignment):
- Compute quantiles on generated vs real x0
- Loss = Huber or L1 difference on quantiles
Stabilization:
```
quantile_loss_warmup_steps
quantile_loss_clip
quantile_loss_huber_delta
```
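
Putting these together, a sketch of the quantile term as it appears in the `train.py` diff above (clipping, then Huber via `smooth_l1_loss`, with the weight warmed up over the first steps):

```
import torch
import torch.nn.functional as F

def quantile_loss(x_real, x_gen, q_points, clip=6.0, huber_delta=1.0):
    q = torch.tensor(q_points, device=x_real.device, dtype=x_real.dtype)
    if clip > 0:  # clamp both sides so outliers cannot distort the quantiles
        x_real = torch.clamp(x_real, -clip, clip)
        x_gen = torch.clamp(x_gen, -clip, clip)
    x_real = x_real.reshape(-1, x_real.size(-1))  # flatten batch/time into rows
    x_gen = x_gen.reshape(-1, x_gen.size(-1))
    q_real = torch.quantile(x_real, q, dim=0)
    q_gen = torch.quantile(x_gen, q, dim=0)
    if huber_delta > 0:  # Huber on the per-quantile gaps
        return F.smooth_l1_loss(q_gen, q_real, beta=huber_delta)
    return torch.mean(torch.abs(q_gen - q_real))

# Warmup: q_weight = q_weight * min(1.0, (total_step + 1) / warmup_steps)
```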
---
## 7. Training Strategy

Defined in `example/train.py`.

**Key techniques:**

- EMA of model weights (a minimal sketch follows this list)
- Gradient clipping
- Shuffle buffer to reduce batch bias
- Optional feature graph prior
- Quantile loss warmup for stability
- Optional stage-1 temporal GRU (trend) + residual diffusion
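
A minimal EMA sketch consistent with how `train.py` uses it (`EMA(model, decay)`, `ema.update(model)`, `ema.state_dict()`); the project's actual class may differ in detail:

```
import torch

class EMA:
    """Exponential moving average of model parameters."""

    def __init__(self, model, decay: float = 0.999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    @torch.no_grad()
    def update(self, model):
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].mul_(self.decay).add_(v, alpha=1.0 - self.decay)
            else:
                self.shadow[k].copy_(v)  # non-float buffers are copied as-is

    def state_dict(self):
        return self.shadow
```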
**Config highlights (example/config.json):**
```
timesteps: 600
batch_size: 128
seq_len: 128
epochs: 10
max_batches: 4000
lambda: 0.7
cont_target: "v"
quantile_loss_weight: 0.1
model_use_feature_graph: true
use_temporal_stage1: true
```
---
## 8. Sampling & Export

Defined in:

- `example/sample.py`
- `example/export_samples.py`

**Export steps** (a sketch of the continuous post-processing follows the list):
- Reverse diffusion with conditional sampling
- Reverse normalize continuous values
- Clamp to observed min/max
- Restore discrete tokens from vocab
- Write to CSV
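
An illustrative sketch of the continuous post-processing; `mean`, `std`, `col_min`, and `col_max` are stand-ins for the per-feature statistics computed by `prepare_data.py`:

```
import torch

def postprocess_continuous(x_gen, mean, std, col_min, col_max):
    # reverse normalization: generated values live in normalized space
    x = x_gen * std + mean
    # clamp each feature to the min/max observed in the training data
    return torch.clamp(x, min=col_min, max=col_max)
```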
---
## 9. Evaluation Metrics

Implemented in `example/evaluate_generated.py`; minimal sketches of the headline metrics follow the lists below.

### Continuous Metrics
- **KS statistic** (distribution similarity per feature)
- **Quantile errors** (q05/q25/q50/q75/q95)
- **Lag-1 correlation diff** (temporal structure)
### Discrete Metrics
- **JSD** over token frequency distribution
- **Invalid token counts**
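
Minimal sketches of the three headline metrics (the evaluation script's exact implementation may differ):

```
import numpy as np
from scipy.stats import ks_2samp

def ks_stat(real, gen):
    # Kolmogorov-Smirnov statistic between two empirical distributions
    return float(ks_2samp(real, gen).statistic)

def jsd(p, q, eps=1e-12):
    # Jensen-Shannon divergence between two token frequency vectors
    p = p / p.sum()
    q = q / q.sum()
    m = 0.5 * (p + q)
    kl = lambda a, b: float(np.sum(a * np.log((a + eps) / (b + eps))))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def lag1_diff(real, gen):
    # absolute gap in lag-1 autocorrelation (temporal structure)
    ac = lambda x: float(np.corrcoef(x[:-1], x[1:])[0, 1])
    return abs(ac(real) - ac(gen))
```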
### Summary Metrics

Auto-logged to `example/results/metrics_history.csv` via `example/summary_metrics.py`.
---
## 10. Automation

### One-click pipeline

```
python example/run_all.py --device cuda
```

### Metrics logging
Each run appends:
```
timestamp,avg_ks,avg_jsd,avg_lag1_diff
```
---
## 11. Key Engineering Decisions

### 11.1 Mixed-Type Diffusion

Continuous and discrete features are handled by separate diffusion branches to respect their data types.

### 11.2 Structure Prior

A learnable feature graph encodes implicit feature dependencies.

### 11.3 v-prediction

Chosen to stabilize training and improve convergence of the continuous diffusion branch.

### 11.4 Distribution Alignment

A quantile loss was introduced to directly reduce the KS statistic.
---
## 12. Known Issues / Current Limitations

- **KS remains high** in many experiments, meaning continuous distributions are still misaligned.
- **Lag-1 may degrade** when the quantile loss is too strong.
- **Loss spikes** were observed when the quantile loss was unstable (mitigated with warmup + clip + Huber).
---
## 13. Suggested Next Steps (Research Roadmap)
1) **SNR-weighted loss** (improve stability across timesteps)
2) **Two-stage training** (distribution first, temporal consistency second)
3) **Upgrade discrete diffusion** (D3PM-style transitions)
4) **Structured conditioning** (state/phase conditioning)
5) **Graph-based priors** (explicit feature/plant dependency graphs)
---
## 14. Code Map (Key Files)
**Core model**
- `example/hybrid_diffusion.py`
**Training**
- `example/train.py`
**Sampling & export**
- `example/sample.py`
- `example/export_samples.py`
**Pipeline**
- `example/run_all.py`
**Evaluation**
- `example/evaluate_generated.py`
- `example/summary_metrics.py`
**Configs**
- `example/config.json`
---
## 15. Summary

This project implements a **hybrid diffusion model for ICS traffic features**, combining continuous Gaussian diffusion with discrete mask diffusion, enhanced with a **learnable feature-graph prior** and an optional two-stage temporal backbone. The system includes a full pipeline for preparation, training, sampling, exporting, and evaluation. Key research challenges remain in **distribution alignment (KS)** and the **joint optimization of distribution fidelity vs. temporal consistency**, motivating future improvements such as SNR-weighted loss, staged training, and stronger structural priors.