Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| dd4c1e171f | |||
|
|
f8edee9510 | ||
| cb610281ce | |||
| bc838d7cd7 | |||
|
|
b3c45010a4 | ||
| 2a1a9a05c6 | |||
| cc10125fbf | |||
|
|
5ef1e465f9 | ||
| 5aba14d511 | |||
|
|
bdfc6e2aaa | ||
|
|
666bc3b8a9 | ||
| a870945e33 |
@@ -72,3 +72,5 @@ python example/run_pipeline.py --device auto
|
|||||||
- The script only samples the first 5000 rows to stay fast.
|
- The script only samples the first 5000 rows to stay fast.
|
||||||
- `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
|
- `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
|
||||||
- `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU.
|
- `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU.
|
||||||
|
- Optional feature-graph mixer (`model_use_feature_graph`) adds a learnable relation prior across feature channels.
|
||||||
|
- Optional two-stage temporal backbone (`use_temporal_stage1`) learns a GRU-based sequence trend; diffusion models the residual.
|
||||||
|
|||||||
@@ -32,11 +32,24 @@
|
|||||||
"model_ff_mult": 2,
|
"model_ff_mult": 2,
|
||||||
"model_pos_dim": 64,
|
"model_pos_dim": 64,
|
||||||
"model_use_pos_embed": true,
|
"model_use_pos_embed": true,
|
||||||
|
"model_use_feature_graph": true,
|
||||||
|
"feature_graph_scale": 0.1,
|
||||||
|
"feature_graph_dropout": 0.0,
|
||||||
|
"use_temporal_stage1": true,
|
||||||
|
"temporal_hidden_dim": 256,
|
||||||
|
"temporal_num_layers": 1,
|
||||||
|
"temporal_dropout": 0.0,
|
||||||
|
"temporal_loss_weight": 1.0,
|
||||||
"disc_mask_scale": 0.9,
|
"disc_mask_scale": 0.9,
|
||||||
"cont_loss_weighting": "inv_std",
|
"cont_loss_weighting": "inv_std",
|
||||||
"cont_loss_eps": 1e-6,
|
"cont_loss_eps": 1e-6,
|
||||||
"cont_target": "x0",
|
"cont_target": "v",
|
||||||
"cont_clamp_x0": 5.0,
|
"cont_clamp_x0": 5.0,
|
||||||
|
"quantile_loss_weight": 0.1,
|
||||||
|
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
|
||||||
|
"quantile_loss_warmup_steps": 200,
|
||||||
|
"quantile_loss_clip": 6.0,
|
||||||
|
"quantile_loss_huber_delta": 1.0,
|
||||||
"shuffle_buffer": 256,
|
"shuffle_buffer": 256,
|
||||||
"sample_batch_size": 8,
|
"sample_batch_size": 8,
|
||||||
"sample_seq_len": 128
|
"sample_seq_len": 128
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from data_utils import load_split
|
from data_utils import load_split
|
||||||
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
|
from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
|
||||||
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
|
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
|
||||||
|
|
||||||
|
|
||||||
@@ -140,6 +140,10 @@ def main():
|
|||||||
raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob)
|
raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob)
|
||||||
cont_target = str(cfg.get("cont_target", "eps"))
|
cont_target = str(cfg.get("cont_target", "eps"))
|
||||||
cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
|
cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
|
||||||
|
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
|
||||||
|
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
|
||||||
|
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
|
||||||
|
temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
|
||||||
|
|
||||||
model = HybridDiffusionModel(
|
model = HybridDiffusionModel(
|
||||||
cont_dim=len(cont_cols),
|
cont_dim=len(cont_cols),
|
||||||
@@ -151,6 +155,9 @@ def main():
|
|||||||
ff_mult=int(cfg.get("model_ff_mult", 2)),
|
ff_mult=int(cfg.get("model_ff_mult", 2)),
|
||||||
pos_dim=int(cfg.get("model_pos_dim", 64)),
|
pos_dim=int(cfg.get("model_pos_dim", 64)),
|
||||||
use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
|
use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
|
||||||
|
use_feature_graph=bool(cfg.get("model_use_feature_graph", False)),
|
||||||
|
feature_graph_scale=float(cfg.get("feature_graph_scale", 0.1)),
|
||||||
|
feature_graph_dropout=float(cfg.get("feature_graph_dropout", 0.0)),
|
||||||
cond_vocab_size=cond_vocab_size if use_condition else 0,
|
cond_vocab_size=cond_vocab_size if use_condition else 0,
|
||||||
cond_dim=int(cfg.get("cond_dim", 32)),
|
cond_dim=int(cfg.get("cond_dim", 32)),
|
||||||
use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
|
use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
|
||||||
@@ -163,6 +170,20 @@ def main():
|
|||||||
model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
|
model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
temporal_model = None
|
||||||
|
if use_temporal_stage1:
|
||||||
|
temporal_model = TemporalGRUGenerator(
|
||||||
|
input_dim=len(cont_cols),
|
||||||
|
hidden_dim=temporal_hidden_dim,
|
||||||
|
num_layers=temporal_num_layers,
|
||||||
|
dropout=temporal_dropout,
|
||||||
|
).to(device)
|
||||||
|
temporal_path = Path(args.model_path).with_name("temporal.pt")
|
||||||
|
if not temporal_path.exists():
|
||||||
|
raise SystemExit(f"missing temporal model file: {temporal_path}")
|
||||||
|
temporal_model.load_state_dict(torch.load(temporal_path, map_location=device, weights_only=True))
|
||||||
|
temporal_model.eval()
|
||||||
|
|
||||||
betas = cosine_beta_schedule(args.timesteps).to(device)
|
betas = cosine_beta_schedule(args.timesteps).to(device)
|
||||||
alphas = 1.0 - betas
|
alphas = 1.0 - betas
|
||||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||||
@@ -189,6 +210,10 @@ def main():
|
|||||||
cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
|
cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
|
||||||
cond = cond_id
|
cond = cond_id
|
||||||
|
|
||||||
|
trend = None
|
||||||
|
if temporal_model is not None:
|
||||||
|
trend = temporal_model.generate(args.batch_size, args.seq_len, device)
|
||||||
|
|
||||||
for t in reversed(range(args.timesteps)):
|
for t in reversed(range(args.timesteps)):
|
||||||
t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
|
t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
|
||||||
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
|
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
|
||||||
@@ -201,6 +226,10 @@ def main():
|
|||||||
if cont_clamp_x0 > 0:
|
if cont_clamp_x0 > 0:
|
||||||
x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
|
x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
|
||||||
eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
|
eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
|
||||||
|
elif cont_target == "v":
|
||||||
|
v_pred = eps_pred
|
||||||
|
x0_pred = torch.sqrt(a_bar_t) * x_cont - torch.sqrt(1.0 - a_bar_t) * v_pred
|
||||||
|
eps_pred = torch.sqrt(1.0 - a_bar_t) * x_cont + torch.sqrt(a_bar_t) * v_pred
|
||||||
coef1 = 1.0 / torch.sqrt(a_t)
|
coef1 = 1.0 / torch.sqrt(a_t)
|
||||||
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
|
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
|
||||||
mean_x = coef1 * (x_cont - coef2 * eps_pred)
|
mean_x = coef1 * (x_cont - coef2 * eps_pred)
|
||||||
@@ -225,6 +254,8 @@ def main():
|
|||||||
)
|
)
|
||||||
x_disc[:, :, i][mask] = sampled[mask]
|
x_disc[:, :, i][mask] = sampled[mask]
|
||||||
|
|
||||||
|
if trend is not None:
|
||||||
|
x_cont = x_cont + trend
|
||||||
# move to CPU for export
|
# move to CPU for export
|
||||||
x_cont = x_cont.cpu()
|
x_cont = x_cont.cpu()
|
||||||
x_disc = x_disc.cpu()
|
x_disc = x_disc.cpu()
|
||||||
|
|||||||
@@ -66,6 +66,64 @@ class SinusoidalTimeEmbedding(nn.Module):
|
|||||||
return emb
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureGraphMixer(nn.Module):
|
||||||
|
"""Learnable feature relation mixer (dataset-agnostic)."""
|
||||||
|
|
||||||
|
def __init__(self, dim: int, scale: float = 0.1, dropout: float = 0.0):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = scale
|
||||||
|
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
||||||
|
self.A = nn.Parameter(torch.zeros(dim, dim))
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
# x: (B, T, D)
|
||||||
|
# Symmetric relation to stabilize
|
||||||
|
A = (self.A + self.A.t()) * 0.5
|
||||||
|
mixed = torch.matmul(x, A) * self.scale
|
||||||
|
mixed = self.dropout(mixed)
|
||||||
|
return x + mixed
|
||||||
|
|
||||||
|
|
||||||
|
class TemporalGRUGenerator(nn.Module):
|
||||||
|
"""Stage-1 temporal generator (autoregressive GRU) for sequence backbone."""
|
||||||
|
|
||||||
|
def __init__(self, input_dim: int, hidden_dim: int = 256, num_layers: int = 1, dropout: float = 0.0):
|
||||||
|
super().__init__()
|
||||||
|
self.start_token = nn.Parameter(torch.zeros(input_dim))
|
||||||
|
self.gru = nn.GRU(
|
||||||
|
input_dim,
|
||||||
|
hidden_dim,
|
||||||
|
num_layers=num_layers,
|
||||||
|
dropout=dropout if num_layers > 1 else 0.0,
|
||||||
|
batch_first=True,
|
||||||
|
)
|
||||||
|
self.out = nn.Linear(hidden_dim, input_dim)
|
||||||
|
|
||||||
|
def forward_teacher(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Teacher-forced next-step prediction. Returns trend sequence and preds."""
|
||||||
|
if x.size(1) < 2:
|
||||||
|
raise ValueError("sequence length must be >= 2 for teacher forcing")
|
||||||
|
inp = x[:, :-1, :]
|
||||||
|
out, _ = self.gru(inp)
|
||||||
|
pred_next = self.out(out)
|
||||||
|
trend = torch.zeros_like(x)
|
||||||
|
trend[:, 0, :] = x[:, 0, :]
|
||||||
|
trend[:, 1:, :] = pred_next
|
||||||
|
return trend, pred_next
|
||||||
|
|
||||||
|
def generate(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
|
||||||
|
"""Autoregressively generate a backbone sequence."""
|
||||||
|
h = None
|
||||||
|
prev = self.start_token.unsqueeze(0).expand(batch_size, -1).to(device)
|
||||||
|
outputs = []
|
||||||
|
for _ in range(seq_len):
|
||||||
|
out, h = self.gru(prev.unsqueeze(1), h)
|
||||||
|
nxt = self.out(out.squeeze(1))
|
||||||
|
outputs.append(nxt.unsqueeze(1))
|
||||||
|
prev = nxt
|
||||||
|
return torch.cat(outputs, dim=1)
|
||||||
|
|
||||||
|
|
||||||
class HybridDiffusionModel(nn.Module):
|
class HybridDiffusionModel(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -78,6 +136,9 @@ class HybridDiffusionModel(nn.Module):
|
|||||||
ff_mult: int = 2,
|
ff_mult: int = 2,
|
||||||
pos_dim: int = 64,
|
pos_dim: int = 64,
|
||||||
use_pos_embed: bool = True,
|
use_pos_embed: bool = True,
|
||||||
|
use_feature_graph: bool = False,
|
||||||
|
feature_graph_scale: float = 0.1,
|
||||||
|
feature_graph_dropout: float = 0.0,
|
||||||
cond_vocab_size: int = 0,
|
cond_vocab_size: int = 0,
|
||||||
cond_dim: int = 32,
|
cond_dim: int = 32,
|
||||||
use_tanh_eps: bool = False,
|
use_tanh_eps: bool = False,
|
||||||
@@ -92,6 +153,7 @@ class HybridDiffusionModel(nn.Module):
|
|||||||
self.eps_scale = eps_scale
|
self.eps_scale = eps_scale
|
||||||
self.pos_dim = pos_dim
|
self.pos_dim = pos_dim
|
||||||
self.use_pos_embed = use_pos_embed
|
self.use_pos_embed = use_pos_embed
|
||||||
|
self.use_feature_graph = use_feature_graph
|
||||||
|
|
||||||
self.cond_vocab_size = cond_vocab_size
|
self.cond_vocab_size = cond_vocab_size
|
||||||
self.cond_dim = cond_dim
|
self.cond_dim = cond_dim
|
||||||
@@ -106,8 +168,17 @@ class HybridDiffusionModel(nn.Module):
|
|||||||
disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)
|
disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)
|
||||||
|
|
||||||
self.cont_proj = nn.Linear(cont_dim, cont_dim)
|
self.cont_proj = nn.Linear(cont_dim, cont_dim)
|
||||||
|
self.feature_dim = cont_dim + disc_embed_dim
|
||||||
|
if use_feature_graph:
|
||||||
|
self.feature_graph = FeatureGraphMixer(
|
||||||
|
self.feature_dim,
|
||||||
|
scale=feature_graph_scale,
|
||||||
|
dropout=feature_graph_dropout,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.feature_graph = None
|
||||||
pos_dim = pos_dim if use_pos_embed else 0
|
pos_dim = pos_dim if use_pos_embed else 0
|
||||||
in_dim = cont_dim + disc_embed_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
|
in_dim = self.feature_dim + time_dim + pos_dim + (cond_dim if self.cond_embed is not None else 0)
|
||||||
self.in_proj = nn.Linear(in_dim, hidden_dim)
|
self.in_proj = nn.Linear(in_dim, hidden_dim)
|
||||||
self.backbone = nn.GRU(
|
self.backbone = nn.GRU(
|
||||||
hidden_dim,
|
hidden_dim,
|
||||||
@@ -149,7 +220,10 @@ class HybridDiffusionModel(nn.Module):
|
|||||||
cond_feat = self.cond_embed(cond).unsqueeze(1).expand(-1, x_cont.size(1), -1)
|
cond_feat = self.cond_embed(cond).unsqueeze(1).expand(-1, x_cont.size(1), -1)
|
||||||
|
|
||||||
cont_feat = self.cont_proj(x_cont)
|
cont_feat = self.cont_proj(x_cont)
|
||||||
parts = [cont_feat, disc_feat, time_emb]
|
feat = torch.cat([cont_feat, disc_feat], dim=-1)
|
||||||
|
if self.feature_graph is not None:
|
||||||
|
feat = self.feature_graph(feat)
|
||||||
|
parts = [feat, time_emb]
|
||||||
if pos_emb is not None:
|
if pos_emb is not None:
|
||||||
parts.append(pos_emb.unsqueeze(0).expand(x_cont.size(0), -1, -1))
|
parts.append(pos_emb.unsqueeze(0).expand(x_cont.size(0), -1, -1))
|
||||||
if cond_feat is not None:
|
if cond_feat is not None:
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ def main():
|
|||||||
loss = []
|
loss = []
|
||||||
loss_cont = []
|
loss_cont = []
|
||||||
loss_disc = []
|
loss_disc = []
|
||||||
|
loss_quant = []
|
||||||
|
|
||||||
with log_path.open("r", encoding="utf-8", newline="") as f:
|
with log_path.open("r", encoding="utf-8", newline="") as f:
|
||||||
reader = csv.DictReader(f)
|
reader = csv.DictReader(f)
|
||||||
@@ -42,6 +43,8 @@ def main():
|
|||||||
loss.append(float(row["loss"]))
|
loss.append(float(row["loss"]))
|
||||||
loss_cont.append(float(row["loss_cont"]))
|
loss_cont.append(float(row["loss_cont"]))
|
||||||
loss_disc.append(float(row["loss_disc"]))
|
loss_disc.append(float(row["loss_disc"]))
|
||||||
|
if "loss_quantile" in row:
|
||||||
|
loss_quant.append(float(row["loss_quantile"]))
|
||||||
|
|
||||||
if not steps:
|
if not steps:
|
||||||
raise SystemExit("no rows in log file: %s" % log_path)
|
raise SystemExit("no rows in log file: %s" % log_path)
|
||||||
@@ -50,6 +53,8 @@ def main():
|
|||||||
plt.plot(steps, loss, label="total")
|
plt.plot(steps, loss, label="total")
|
||||||
plt.plot(steps, loss_cont, label="continuous")
|
plt.plot(steps, loss_cont, label="continuous")
|
||||||
plt.plot(steps, loss_disc, label="discrete")
|
plt.plot(steps, loss_disc, label="discrete")
|
||||||
|
if loss_quant:
|
||||||
|
plt.plot(steps, loss_quant, label="quantile")
|
||||||
plt.xlabel("step")
|
plt.xlabel("step")
|
||||||
plt.ylabel("loss")
|
plt.ylabel("loss")
|
||||||
plt.title("Training Loss")
|
plt.title("Training Loss")
|
||||||
|
|||||||
@@ -32,12 +32,26 @@
|
|||||||
"model_ff_mult": 2,
|
"model_ff_mult": 2,
|
||||||
"model_pos_dim": 64,
|
"model_pos_dim": 64,
|
||||||
"model_use_pos_embed": true,
|
"model_use_pos_embed": true,
|
||||||
|
"model_use_feature_graph": true,
|
||||||
|
"feature_graph_scale": 0.1,
|
||||||
|
"feature_graph_dropout": 0.0,
|
||||||
"disc_mask_scale": 0.9,
|
"disc_mask_scale": 0.9,
|
||||||
"shuffle_buffer": 256,
|
"shuffle_buffer": 256,
|
||||||
"cont_loss_weighting": "inv_std",
|
"cont_loss_weighting": "inv_std",
|
||||||
"cont_loss_eps": 1e-06,
|
"cont_loss_eps": 1e-06,
|
||||||
"cont_target": "x0",
|
"cont_target": "v",
|
||||||
"cont_clamp_x0": 5.0,
|
"cont_clamp_x0": 5.0,
|
||||||
|
"quantile_loss_weight": 0.1,
|
||||||
|
"quantile_points": [
|
||||||
|
0.05,
|
||||||
|
0.25,
|
||||||
|
0.5,
|
||||||
|
0.75,
|
||||||
|
0.95
|
||||||
|
],
|
||||||
|
"quantile_loss_warmup_steps": 200,
|
||||||
|
"quantile_loss_clip": 6.0,
|
||||||
|
"quantile_loss_huber_delta": 1.0,
|
||||||
"sample_batch_size": 8,
|
"sample_batch_size": 8,
|
||||||
"sample_seq_len": 128
|
"sample_seq_len": 128
|
||||||
}
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,61 +1,61 @@
|
|||||||
epoch,step,loss,loss_cont,loss_disc
|
epoch,step,loss,loss_cont,loss_disc,loss_quantile
|
||||||
0,0,9801.450195,14001.333008,1.724242
|
0,0,11527.750000,16467.474609,1.724242,0.113109
|
||||||
0,10,8648.388672,12354.390625,1.050415
|
0,10,8955.048828,12792.469727,1.065955,0.142707
|
||||||
0,20,6285.666992,8979.191406,0.775979
|
0,20,123049.507812,175784.562500,1.048191,0.181546
|
||||||
0,30,6296.939941,8995.372070,0.598658
|
0,30,30947.671875,44210.511719,1.034107,0.224363
|
||||||
0,40,7126.128906,10180.011719,0.402291
|
0,40,11166.339844,15951.691406,0.512211,0.163978
|
||||||
0,50,5381.071289,7687.099121,0.340463
|
0,50,14919.276367,21313.089844,0.369016,0.101041
|
||||||
1,0,17876.904297,25538.212891,0.522857
|
1,0,1113425.000000,1590606.875000,0.667305,0.236102
|
||||||
1,10,9174.909180,13106.868164,0.338315
|
1,10,38804.527344,55433.980469,2.453400,0.238323
|
||||||
1,20,7635.462891,10907.713867,0.211883
|
1,20,138075.984375,197251.281250,0.288013,0.157941
|
||||||
1,30,5425.212402,7750.165039,0.323606
|
1,30,56904.078125,81291.343750,0.429261,0.174654
|
||||||
1,40,4372.716309,6246.610840,0.296102
|
1,40,10662.019531,15231.303711,0.343701,0.079876
|
||||||
1,50,3846.437988,5494.793945,0.273941
|
1,50,9890.013672,14128.458008,0.292848,0.088391
|
||||||
2,0,17057.958984,24368.269531,0.564671
|
2,0,912398.500000,1303426.125000,0.795286,0.256503
|
||||||
2,10,7089.009766,10127.021484,0.315390
|
2,10,12546.185547,17922.542969,1.328752,0.110444
|
||||||
2,20,4230.856445,6043.994141,0.201429
|
2,20,109676.710938,156680.906250,0.245523,0.132549
|
||||||
2,30,3744.107910,5348.593262,0.309118
|
2,30,49427.507812,70610.578125,0.327188,0.088591
|
||||||
2,40,3531.041992,5044.219238,0.295744
|
2,40,27778.673828,39683.660156,0.345683,0.102361
|
||||||
2,50,3570.459229,5100.528320,0.297530
|
2,50,10311.509766,14730.566406,0.350199,0.091674
|
||||||
3,0,14367.601562,20524.917969,0.529623
|
3,0,1040308.062500,1486154.000000,0.807157,0.279805
|
||||||
3,10,6734.334473,9620.347656,0.304721
|
3,10,64799.246094,92570.117188,0.485949,0.198213
|
||||||
3,20,6140.179688,8771.602539,0.193836
|
3,20,336018.000000,480025.531250,0.410466,0.118048
|
||||||
3,30,4089.454102,5841.940918,0.317864
|
3,30,94216.312500,134594.562500,0.355044,0.114209
|
||||||
3,40,3553.830811,5076.785645,0.269531
|
3,40,19988.919922,28555.457031,0.298066,0.094291
|
||||||
3,50,3590.448242,5129.088867,0.287063
|
3,50,9181.969727,13116.940430,0.326489,0.137392
|
||||||
4,0,14410.816406,20586.648438,0.543822
|
4,0,741176.187500,1058822.750000,0.850633,0.233118
|
||||||
4,10,6411.443359,9159.058594,0.341742
|
4,10,39252.617188,56074.410156,1.690607,0.227370
|
||||||
4,20,3816.795166,5452.479492,0.198213
|
4,20,108992.304688,155703.140625,0.258920,0.285944
|
||||||
4,30,4069.170898,5812.959961,0.329676
|
4,30,37115.253906,53021.632812,0.337674,0.131537
|
||||||
4,40,3484.921631,4978.332520,0.296284
|
4,40,19358.708984,27655.123047,0.349937,0.169446
|
||||||
4,50,2802.801514,4003.873779,0.299286
|
4,50,11434.291992,16334.540039,0.351904,0.089116
|
||||||
5,0,13335.293945,19050.201172,0.509557
|
5,0,845658.312500,1208083.000000,0.922483,0.285983
|
||||||
5,10,5531.156738,7901.527344,0.293409
|
5,10,130569.406250,186527.515625,0.405198,0.227639
|
||||||
5,20,3844.260010,5491.696777,0.241263
|
5,20,245780.390625,351114.687500,0.301236,0.147030
|
||||||
5,30,3619.303223,5170.297363,0.317237
|
5,30,42017.671875,60025.066406,0.372036,0.117895
|
||||||
5,40,3492.172852,4988.697754,0.281641
|
5,40,11496.740234,16423.779297,0.286911,0.085701
|
||||||
5,50,3069.457275,4384.815918,0.287269
|
5,50,8891.728516,12702.317383,0.322913,0.099181
|
||||||
6,0,8740.982422,12486.912109,0.483061
|
6,0,617909.687500,882727.812500,0.834663,0.205251
|
||||||
6,10,6110.571777,8729.239258,0.347929
|
6,10,18171.734375,25959.302734,0.677465,0.190570
|
||||||
6,20,3350.194092,4785.889160,0.239650
|
6,20,423716.187500,605308.687500,0.464189,0.156125
|
||||||
6,30,3008.237549,4297.353516,0.300327
|
6,30,48133.507812,68761.914062,0.478786,0.221465
|
||||||
6,40,2944.483887,4206.288574,0.273457
|
6,40,20350.281250,29071.666016,0.343182,0.118739
|
||||||
6,50,3018.033447,4311.365234,0.259749
|
6,50,11372.219727,16245.889648,0.289177,0.101218
|
||||||
7,0,6784.341309,9691.704102,0.495378
|
7,0,534133.500000,763047.500000,0.732465,0.239962
|
||||||
7,10,4946.872559,7066.822754,0.321541
|
7,10,19025.574219,27178.703125,1.551094,0.183023
|
||||||
7,20,2816.704102,4023.745361,0.274515
|
7,20,58042.757812,82918.078125,0.298854,0.146532
|
||||||
7,30,2991.350830,4273.229980,0.299238
|
7,30,11810.656250,16872.201172,0.353035,0.092791
|
||||||
7,40,3023.450684,4319.083984,0.307114
|
7,40,9733.656250,13905.077148,0.312043,0.089972
|
||||||
7,50,2944.912598,4206.909668,0.253002
|
7,50,9069.159180,12955.803711,0.280795,0.129749
|
||||||
8,0,7364.851562,10521.002930,0.498283
|
8,0,650366.937500,929095.250000,0.787194,0.263791
|
||||||
8,10,3897.301025,5567.439453,0.311517
|
8,10,13121.248047,18744.031250,1.376233,0.141053
|
||||||
8,20,3313.474854,4733.448242,0.203364
|
8,20,19326.507812,27609.167969,0.243424,0.183426
|
||||||
8,30,2697.139648,3852.940430,0.271355
|
8,30,12904.376953,18434.667969,0.334577,0.092547
|
||||||
8,40,2955.225342,4221.633301,0.273175
|
8,40,9682.833984,13832.478516,0.296229,0.109518
|
||||||
8,50,2932.081787,4188.575684,0.262651
|
8,50,8866.144531,12665.773438,0.312481,0.099564
|
||||||
9,0,4065.651611,5807.854004,0.513017
|
9,0,265689.562500,379556.156250,0.738840,0.186122
|
||||||
9,10,4358.108398,6225.728516,0.329111
|
9,10,9944.607422,14206.222656,0.806142,0.096763
|
||||||
9,20,3417.019043,4881.362305,0.218488
|
9,20,22091.175781,31558.679688,0.285403,0.131950
|
||||||
9,30,2917.136719,4167.226562,0.260805
|
9,30,19346.830078,27638.111328,0.460825,0.131496
|
||||||
9,40,2786.277832,3980.287109,0.256345
|
9,40,10279.702148,14685.139648,0.314872,0.095758
|
||||||
9,50,2602.660645,3717.971680,0.268723
|
9,50,10420.340820,14886.050781,0.314008,0.112904
|
||||||
|
|||||||
|
@@ -75,6 +75,7 @@ def main():
|
|||||||
run([sys.executable, str(base_dir / "evaluate_generated.py"), "--reference", str(ref)])
|
run([sys.executable, str(base_dir / "evaluate_generated.py"), "--reference", str(ref)])
|
||||||
else:
|
else:
|
||||||
run([sys.executable, str(base_dir / "evaluate_generated.py")])
|
run([sys.executable, str(base_dir / "evaluate_generated.py")])
|
||||||
|
run([sys.executable, str(base_dir / "summary_metrics.py")])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from data_utils import load_split
|
from data_utils import load_split
|
||||||
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
|
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
|
||||||
|
from hybrid_diffusion import TemporalGRUGenerator
|
||||||
from platform_utils import resolve_device, safe_path, ensure_dir
|
from platform_utils import resolve_device, safe_path, ensure_dir
|
||||||
|
|
||||||
BASE_DIR = Path(__file__).resolve().parent
|
BASE_DIR = Path(__file__).resolve().parent
|
||||||
@@ -56,6 +57,13 @@ def main():
|
|||||||
model_ff_mult = int(cfg.get("model_ff_mult", 2))
|
model_ff_mult = int(cfg.get("model_ff_mult", 2))
|
||||||
model_pos_dim = int(cfg.get("model_pos_dim", 64))
|
model_pos_dim = int(cfg.get("model_pos_dim", 64))
|
||||||
model_use_pos = bool(cfg.get("model_use_pos_embed", True))
|
model_use_pos = bool(cfg.get("model_use_pos_embed", True))
|
||||||
|
model_use_feature_graph = bool(cfg.get("model_use_feature_graph", False))
|
||||||
|
feature_graph_scale = float(cfg.get("feature_graph_scale", 0.1))
|
||||||
|
feature_graph_dropout = float(cfg.get("feature_graph_dropout", 0.0))
|
||||||
|
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
|
||||||
|
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
|
||||||
|
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
|
||||||
|
temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
|
||||||
|
|
||||||
split = load_split(str(SPLIT_PATH))
|
split = load_split(str(SPLIT_PATH))
|
||||||
time_col = split.get("time_column", "time")
|
time_col = split.get("time_column", "time")
|
||||||
@@ -83,6 +91,9 @@ def main():
|
|||||||
ff_mult=model_ff_mult,
|
ff_mult=model_ff_mult,
|
||||||
pos_dim=model_pos_dim,
|
pos_dim=model_pos_dim,
|
||||||
use_pos_embed=model_use_pos,
|
use_pos_embed=model_use_pos,
|
||||||
|
use_feature_graph=model_use_feature_graph,
|
||||||
|
feature_graph_scale=feature_graph_scale,
|
||||||
|
feature_graph_dropout=feature_graph_dropout,
|
||||||
cond_vocab_size=cond_vocab_size,
|
cond_vocab_size=cond_vocab_size,
|
||||||
cond_dim=cond_dim,
|
cond_dim=cond_dim,
|
||||||
use_tanh_eps=use_tanh_eps,
|
use_tanh_eps=use_tanh_eps,
|
||||||
@@ -92,6 +103,20 @@ def main():
|
|||||||
model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
|
model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
temporal_model = None
|
||||||
|
if use_temporal_stage1:
|
||||||
|
temporal_model = TemporalGRUGenerator(
|
||||||
|
input_dim=len(cont_cols),
|
||||||
|
hidden_dim=temporal_hidden_dim,
|
||||||
|
num_layers=temporal_num_layers,
|
||||||
|
dropout=temporal_dropout,
|
||||||
|
).to(DEVICE)
|
||||||
|
temporal_path = BASE_DIR / "results" / "temporal.pt"
|
||||||
|
if not temporal_path.exists():
|
||||||
|
raise SystemExit(f"missing temporal model file: {temporal_path}")
|
||||||
|
temporal_model.load_state_dict(torch.load(str(temporal_path), map_location=DEVICE, weights_only=True))
|
||||||
|
temporal_model.eval()
|
||||||
|
|
||||||
betas = cosine_beta_schedule(timesteps).to(DEVICE)
|
betas = cosine_beta_schedule(timesteps).to(DEVICE)
|
||||||
alphas = 1.0 - betas
|
alphas = 1.0 - betas
|
||||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||||
@@ -110,19 +135,26 @@ def main():
|
|||||||
raise SystemExit("use_condition enabled but no files matched data_glob")
|
raise SystemExit("use_condition enabled but no files matched data_glob")
|
||||||
cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long)
|
cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long)
|
||||||
|
|
||||||
|
trend = None
|
||||||
|
if temporal_model is not None:
|
||||||
|
trend = temporal_model.generate(batch_size, seq_len, DEVICE)
|
||||||
|
|
||||||
for t in reversed(range(timesteps)):
|
for t in reversed(range(timesteps)):
|
||||||
t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
|
t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
|
||||||
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
|
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
|
||||||
|
|
||||||
|
# Continuous reverse step (DDPM): x_{t-1} mean
|
||||||
|
a_t = alphas[t]
|
||||||
|
a_bar_t = alphas_cumprod[t]
|
||||||
if cont_target == "x0":
|
if cont_target == "x0":
|
||||||
x0_pred = eps_pred
|
x0_pred = eps_pred
|
||||||
if cont_clamp_x0 > 0:
|
if cont_clamp_x0 > 0:
|
||||||
x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
|
x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
|
||||||
eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
|
eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
|
||||||
|
elif cont_target == "v":
|
||||||
# Continuous reverse step (DDPM): x_{t-1} mean
|
v_pred = eps_pred
|
||||||
a_t = alphas[t]
|
x0_pred = torch.sqrt(a_bar_t) * x_cont - torch.sqrt(1.0 - a_bar_t) * v_pred
|
||||||
a_bar_t = alphas_cumprod[t]
|
eps_pred = torch.sqrt(1.0 - a_bar_t) * x_cont + torch.sqrt(a_bar_t) * v_pred
|
||||||
coef1 = 1.0 / torch.sqrt(a_t)
|
coef1 = 1.0 / torch.sqrt(a_t)
|
||||||
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
|
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
|
||||||
mean = coef1 * (x_cont - coef2 * eps_pred)
|
mean = coef1 * (x_cont - coef2 * eps_pred)
|
||||||
@@ -146,6 +178,8 @@ def main():
|
|||||||
sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(BATCH_SIZE, SEQ_LEN)
|
sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(BATCH_SIZE, SEQ_LEN)
|
||||||
x_disc[:, :, i][mask] = sampled[mask]
|
x_disc[:, :, i][mask] = sampled[mask]
|
||||||
|
|
||||||
|
if trend is not None:
|
||||||
|
x_cont = x_cont + trend
|
||||||
print("sampled_cont_shape", tuple(x_cont.shape))
|
print("sampled_cont_shape", tuple(x_cont.shape))
|
||||||
print("sampled_disc_shape", tuple(x_disc.shape))
|
print("sampled_disc_shape", tuple(x_disc.shape))
|
||||||
|
|
||||||
|
|||||||
40
example/summary_metrics.py
Normal file
40
example/summary_metrics.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Print average metrics from eval.json for quick tracking."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def mean(values):
|
||||||
|
return sum(values) / len(values) if values else None
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
base_dir = Path(__file__).resolve().parent
|
||||||
|
eval_path = base_dir / "results" / "eval.json"
|
||||||
|
if not eval_path.exists():
|
||||||
|
raise SystemExit(f"missing eval.json: {eval_path}")
|
||||||
|
|
||||||
|
obj = json.loads(eval_path.read_text(encoding="utf-8"))
|
||||||
|
ks = list(obj.get("continuous_ks", {}).values())
|
||||||
|
jsd = list(obj.get("discrete_jsd", {}).values())
|
||||||
|
lag = list(obj.get("continuous_lag1_diff", {}).values())
|
||||||
|
|
||||||
|
avg_ks = mean(ks)
|
||||||
|
avg_jsd = mean(jsd)
|
||||||
|
avg_lag1 = mean(lag)
|
||||||
|
|
||||||
|
print("avg_ks", avg_ks)
|
||||||
|
print("avg_jsd", avg_jsd)
|
||||||
|
print("avg_lag1_diff", avg_lag1)
|
||||||
|
|
||||||
|
history_path = base_dir / "results" / "metrics_history.csv"
|
||||||
|
if not history_path.exists():
|
||||||
|
history_path.write_text("timestamp,avg_ks,avg_jsd,avg_lag1_diff\n", encoding="utf-8")
|
||||||
|
with history_path.open("a", encoding="utf-8") as f:
|
||||||
|
f.write(f"{datetime.utcnow().isoformat()},{avg_ks},{avg_jsd},{avg_lag1}\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
122
example/train.py
122
example/train.py
@@ -14,6 +14,7 @@ import torch.nn.functional as F
|
|||||||
from data_utils import load_split, windowed_batches
|
from data_utils import load_split, windowed_batches
|
||||||
from hybrid_diffusion import (
|
from hybrid_diffusion import (
|
||||||
HybridDiffusionModel,
|
HybridDiffusionModel,
|
||||||
|
TemporalGRUGenerator,
|
||||||
cosine_beta_schedule,
|
cosine_beta_schedule,
|
||||||
q_sample_continuous,
|
q_sample_continuous,
|
||||||
q_sample_discrete,
|
q_sample_discrete,
|
||||||
@@ -58,12 +59,25 @@ DEFAULTS = {
|
|||||||
"model_ff_mult": 2,
|
"model_ff_mult": 2,
|
||||||
"model_pos_dim": 64,
|
"model_pos_dim": 64,
|
||||||
"model_use_pos_embed": True,
|
"model_use_pos_embed": True,
|
||||||
|
"model_use_feature_graph": True,
|
||||||
|
"feature_graph_scale": 0.1,
|
||||||
|
"feature_graph_dropout": 0.0,
|
||||||
|
"use_temporal_stage1": True,
|
||||||
|
"temporal_hidden_dim": 256,
|
||||||
|
"temporal_num_layers": 1,
|
||||||
|
"temporal_dropout": 0.0,
|
||||||
|
"temporal_loss_weight": 1.0,
|
||||||
"disc_mask_scale": 0.9,
|
"disc_mask_scale": 0.9,
|
||||||
"shuffle_buffer": 256,
|
"shuffle_buffer": 256,
|
||||||
"cont_loss_weighting": "none", # none | inv_std
|
"cont_loss_weighting": "none", # none | inv_std
|
||||||
"cont_loss_eps": 1e-6,
|
"cont_loss_eps": 1e-6,
|
||||||
"cont_target": "eps", # eps | x0
|
"cont_target": "eps", # eps | x0 | v
|
||||||
"cont_clamp_x0": 0.0,
|
"cont_clamp_x0": 0.0,
|
||||||
|
"quantile_loss_weight": 0.0,
|
||||||
|
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
|
||||||
|
"quantile_loss_warmup_steps": 200,
|
||||||
|
"quantile_loss_clip": 6.0,
|
||||||
|
"quantile_loss_huber_delta": 1.0,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -188,12 +202,27 @@ def main():
|
|||||||
ff_mult=int(config.get("model_ff_mult", 2)),
|
ff_mult=int(config.get("model_ff_mult", 2)),
|
||||||
pos_dim=int(config.get("model_pos_dim", 64)),
|
pos_dim=int(config.get("model_pos_dim", 64)),
|
||||||
use_pos_embed=bool(config.get("model_use_pos_embed", True)),
|
use_pos_embed=bool(config.get("model_use_pos_embed", True)),
|
||||||
|
use_feature_graph=bool(config.get("model_use_feature_graph", False)),
|
||||||
|
feature_graph_scale=float(config.get("feature_graph_scale", 0.1)),
|
||||||
|
feature_graph_dropout=float(config.get("feature_graph_dropout", 0.0)),
|
||||||
cond_vocab_size=cond_vocab_size,
|
cond_vocab_size=cond_vocab_size,
|
||||||
cond_dim=int(config.get("cond_dim", 32)),
|
cond_dim=int(config.get("cond_dim", 32)),
|
||||||
use_tanh_eps=bool(config.get("use_tanh_eps", False)),
|
use_tanh_eps=bool(config.get("use_tanh_eps", False)),
|
||||||
eps_scale=float(config.get("eps_scale", 1.0)),
|
eps_scale=float(config.get("eps_scale", 1.0)),
|
||||||
).to(device)
|
).to(device)
|
||||||
|
temporal_model = None
|
||||||
|
if bool(config.get("use_temporal_stage1", False)):
|
||||||
|
temporal_model = TemporalGRUGenerator(
|
||||||
|
input_dim=len(cont_cols),
|
||||||
|
hidden_dim=int(config.get("temporal_hidden_dim", 256)),
|
||||||
|
num_layers=int(config.get("temporal_num_layers", 1)),
|
||||||
|
dropout=float(config.get("temporal_dropout", 0.0)),
|
||||||
|
).to(device)
|
||||||
opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
|
opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
|
||||||
|
if temporal_model is not None:
|
||||||
|
opt_temporal = torch.optim.Adam(temporal_model.parameters(), lr=float(config["lr"]))
|
||||||
|
else:
|
||||||
|
opt_temporal = None
|
||||||
ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None
|
ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None
|
||||||
|
|
||||||
betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
|
betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
|
||||||
@@ -203,8 +232,12 @@ def main():
|
|||||||
os.makedirs(config["out_dir"], exist_ok=True)
|
os.makedirs(config["out_dir"], exist_ok=True)
|
||||||
out_dir = safe_path(config["out_dir"])
|
out_dir = safe_path(config["out_dir"])
|
||||||
log_path = os.path.join(out_dir, "train_log.csv")
|
log_path = os.path.join(out_dir, "train_log.csv")
|
||||||
|
use_quantile = float(config.get("quantile_loss_weight", 0.0)) > 0
|
||||||
with open(log_path, "w", encoding="utf-8") as f:
|
with open(log_path, "w", encoding="utf-8") as f:
|
||||||
f.write("epoch,step,loss,loss_cont,loss_disc\n")
|
if use_quantile:
|
||||||
|
f.write("epoch,step,loss,loss_cont,loss_disc,loss_quantile\n")
|
||||||
|
else:
|
||||||
|
f.write("epoch,step,loss,loss_cont,loss_disc\n")
|
||||||
with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f:
|
with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f:
|
||||||
json.dump(config, f, indent=2)
|
json.dump(config, f, indent=2)
|
||||||
|
|
||||||
@@ -235,10 +268,20 @@ def main():
|
|||||||
x_cont = x_cont.to(device)
|
x_cont = x_cont.to(device)
|
||||||
x_disc = x_disc.to(device)
|
x_disc = x_disc.to(device)
|
||||||
|
|
||||||
|
temporal_loss = None
|
||||||
|
x_cont_resid = x_cont
|
||||||
|
trend = None
|
||||||
|
if temporal_model is not None:
|
||||||
|
trend, pred_next = temporal_model.forward_teacher(x_cont)
|
||||||
|
temporal_loss = F.mse_loss(pred_next, x_cont[:, 1:, :])
|
||||||
|
temporal_loss = temporal_loss * float(config.get("temporal_loss_weight", 1.0))
|
||||||
|
trend = trend.detach()
|
||||||
|
x_cont_resid = x_cont - trend
|
||||||
|
|
||||||
bsz = x_cont.size(0)
|
bsz = x_cont.size(0)
|
||||||
t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
|
t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
|
||||||
|
|
||||||
x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
|
x_cont_t, noise = q_sample_continuous(x_cont_resid, t, alphas_cumprod)
|
||||||
|
|
||||||
mask_tokens = torch.tensor(vocab_sizes, device=device)
|
mask_tokens = torch.tensor(vocab_sizes, device=device)
|
||||||
x_disc_t, mask = q_sample_discrete(
|
x_disc_t, mask = q_sample_discrete(
|
||||||
@@ -253,10 +296,14 @@ def main():
|
|||||||
|
|
||||||
cont_target = str(config.get("cont_target", "eps"))
|
cont_target = str(config.get("cont_target", "eps"))
|
||||||
if cont_target == "x0":
|
if cont_target == "x0":
|
||||||
x0_target = x_cont
|
x0_target = x_cont_resid
|
||||||
if float(config.get("cont_clamp_x0", 0.0)) > 0:
|
if float(config.get("cont_clamp_x0", 0.0)) > 0:
|
||||||
x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"]))
|
x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"]))
|
||||||
loss_base = (eps_pred - x0_target) ** 2
|
loss_base = (eps_pred - x0_target) ** 2
|
||||||
|
elif cont_target == "v":
|
||||||
|
a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
|
||||||
|
v_target = torch.sqrt(a_bar_t) * noise - torch.sqrt(1.0 - a_bar_t) * x_cont_resid
|
||||||
|
loss_base = (eps_pred - v_target) ** 2
|
||||||
else:
|
else:
|
||||||
loss_base = (eps_pred - noise) ** 2
|
loss_base = (eps_pred - noise) ** 2
|
||||||
|
|
||||||
@@ -282,21 +329,76 @@ def main():
|
|||||||
|
|
||||||
lam = float(config["lambda"])
|
lam = float(config["lambda"])
|
||||||
loss = lam * loss_cont + (1 - lam) * loss_disc
|
loss = lam * loss_cont + (1 - lam) * loss_disc
|
||||||
|
|
||||||
|
q_weight = float(config.get("quantile_loss_weight", 0.0))
|
||||||
|
quantile_loss = 0.0
|
||||||
|
if q_weight > 0:
|
||||||
|
warmup = int(config.get("quantile_loss_warmup_steps", 0))
|
||||||
|
if warmup > 0:
|
||||||
|
q_weight = q_weight * min(1.0, (total_step + 1) / float(warmup))
|
||||||
|
q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
|
||||||
|
q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
|
||||||
|
# Use normalized space for stable quantiles on x0.
|
||||||
|
x_real = x_cont_resid
|
||||||
|
a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
|
||||||
|
if cont_target == "x0":
|
||||||
|
x_gen = eps_pred
|
||||||
|
elif cont_target == "v":
|
||||||
|
v_pred = eps_pred
|
||||||
|
x_gen = torch.sqrt(a_bar_t) * x_cont_t - torch.sqrt(1.0 - a_bar_t) * v_pred
|
||||||
|
else:
|
||||||
|
# eps prediction
|
||||||
|
x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
|
||||||
|
q_clip = float(config.get("quantile_loss_clip", 0.0))
|
||||||
|
if q_clip > 0:
|
||||||
|
x_real = torch.clamp(x_real, -q_clip, q_clip)
|
||||||
|
x_gen = torch.clamp(x_gen, -q_clip, q_clip)
|
||||||
|
x_real = x_real.view(-1, x_real.size(-1))
|
||||||
|
x_gen = x_gen.view(-1, x_gen.size(-1))
|
||||||
|
q_real = torch.quantile(x_real, q_tensor, dim=0)
|
||||||
|
q_gen = torch.quantile(x_gen, q_tensor, dim=0)
|
||||||
|
q_delta = float(config.get("quantile_loss_huber_delta", 0.0))
|
||||||
|
q_diff = q_gen - q_real
|
||||||
|
if q_delta > 0:
|
||||||
|
quantile_loss = torch.nn.functional.smooth_l1_loss(q_gen, q_real, beta=q_delta)
|
||||||
|
else:
|
||||||
|
quantile_loss = torch.mean(torch.abs(q_diff))
|
||||||
|
loss = loss + q_weight * quantile_loss
|
||||||
|
|
||||||
opt.zero_grad()
|
opt.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
if float(config.get("grad_clip", 0.0)) > 0:
|
if float(config.get("grad_clip", 0.0)) > 0:
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
|
torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
|
||||||
opt.step()
|
opt.step()
|
||||||
|
if opt_temporal is not None:
|
||||||
|
opt_temporal.zero_grad()
|
||||||
|
temporal_loss.backward()
|
||||||
|
if float(config.get("grad_clip", 0.0)) > 0:
|
||||||
|
torch.nn.utils.clip_grad_norm_(temporal_model.parameters(), float(config["grad_clip"]))
|
||||||
|
opt_temporal.step()
|
||||||
if ema is not None:
|
if ema is not None:
|
||||||
ema.update(model)
|
ema.update(model)
|
||||||
|
|
||||||
if step % int(config["log_every"]) == 0:
|
if step % int(config["log_every"]) == 0:
|
||||||
print("epoch", epoch, "step", step, "loss", float(loss))
|
print("epoch", epoch, "step", step, "loss", float(loss))
|
||||||
with open(log_path, "a", encoding="utf-8") as f:
|
with open(log_path, "a", encoding="utf-8") as f:
|
||||||
f.write(
|
if use_quantile:
|
||||||
"%d,%d,%.6f,%.6f,%.6f\n"
|
f.write(
|
||||||
% (epoch, step, float(loss), float(loss_cont), float(loss_disc))
|
"%d,%d,%.6f,%.6f,%.6f,%.6f\n"
|
||||||
)
|
% (
|
||||||
|
epoch,
|
||||||
|
step,
|
||||||
|
float(loss),
|
||||||
|
float(loss_cont),
|
||||||
|
float(loss_disc),
|
||||||
|
float(quantile_loss),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
f.write(
|
||||||
|
"%d,%d,%.6f,%.6f,%.6f\n"
|
||||||
|
% (epoch, step, float(loss), float(loss_cont), float(loss_disc))
|
||||||
|
)
|
||||||
|
|
||||||
total_step += 1
|
total_step += 1
|
||||||
if total_step % int(config["ckpt_every"]) == 0:
|
if total_step % int(config["ckpt_every"]) == 0:
|
||||||
@@ -308,11 +410,15 @@ def main():
|
|||||||
}
|
}
|
||||||
if ema is not None:
|
if ema is not None:
|
||||||
ckpt["ema"] = ema.state_dict()
|
ckpt["ema"] = ema.state_dict()
|
||||||
|
if temporal_model is not None:
|
||||||
|
ckpt["temporal"] = temporal_model.state_dict()
|
||||||
torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))
|
torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))
|
||||||
|
|
||||||
torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
|
torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
|
||||||
if ema is not None:
|
if ema is not None:
|
||||||
torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))
|
torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))
|
||||||
|
if temporal_model is not None:
|
||||||
|
torch.save(temporal_model.state_dict(), os.path.join(out_dir, "temporal.pt"))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
365
report.md
Normal file
365
report.md
Normal file
@@ -0,0 +1,365 @@
|
|||||||
|
# Hybrid Diffusion for ICS Traffic (HAI 21.03) — Project Report
|
||||||
|
# 工业控制系统流量混合扩散生成(HAI 21.03)— 项目报告
|
||||||
|
|
||||||
|
## 1. Project Goal / 项目目标
|
||||||
|
Build a **hybrid diffusion-based generator** for industrial control system (ICS) traffic features, targeting **mixed continuous + discrete** feature sequences. The output is **feature-level sequences**, not raw packets. The generator should preserve:
|
||||||
|
- **Distributional fidelity** (continuous value ranges and discrete frequencies)
|
||||||
|
- **Temporal consistency** (time correlation and sequence structure)
|
||||||
|
- **Protocol/field consistency** (for discrete fields)
|
||||||
|
|
||||||
|
构建一个用于工业控制系统(ICS)流量特征的**混合扩散生成模型**,面向**连续+离散混合特征序列**。输出为**特征级序列**而非原始报文。生成结果需要同时保持:
|
||||||
|
- **分布一致性**(连续值范围与离散取值频率)
|
||||||
|
- **时序一致性**(时间相关性与序列结构)
|
||||||
|
- **字段/协议一致性**(离散字段的逻辑一致)
|
||||||
|
|
||||||
|
This project is aligned with the STOUTER idea of **structure-aware diffusion** for spatiotemporal data, but applied to **ICS feature sequences** rather than cellular traffic.
|
||||||
|
|
||||||
|
本项目呼应 STOUTER 的**结构先验+扩散**思想,但应用于**ICS 特征序列**而非蜂窝流量。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Data and Scope / 数据与范围
|
||||||
|
**Dataset used in the current implementation:** HAI 21.03 (CSV feature traces)
|
||||||
|
|
||||||
|
**当前实现使用的数据集:** HAI 21.03(CSV 特征序列)
|
||||||
|
|
||||||
|
**Data location (default in config):**
|
||||||
|
- `dataset/hai/hai-21.03/train*.csv.gz`
|
||||||
|
|
||||||
|
**数据位置(config 默认):**
|
||||||
|
- `dataset/hai/hai-21.03/train*.csv.gz`
|
||||||
|
|
||||||
|
**Feature split (fixed schema):**
|
||||||
|
- Defined in `example/feature_split.json`
|
||||||
|
- **Continuous features:** sensor/process values
|
||||||
|
- **Discrete features:** binary/low-cardinality status/flag fields
|
||||||
|
- `time` column is excluded from modeling
|
||||||
|
|
||||||
|
**特征拆分(固定 schema):**
|
||||||
|
- `example/feature_split.json`
|
||||||
|
- **连续特征:** 传感器/过程值
|
||||||
|
- **离散特征:** 二值/低基数状态字段
|
||||||
|
- `time` 列不参与训练
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. End-to-End Pipeline / 端到端流程
|
||||||
|
|
||||||
|
**One command pipeline:**
|
||||||
|
```
|
||||||
|
python example/run_all.py --device cuda
|
||||||
|
```
|
||||||
|
|
||||||
|
**一键流程:**
|
||||||
|
```
|
||||||
|
python example/run_all.py --device cuda
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pipeline stages / 流程阶段
|
||||||
|
1) **Prepare data** (`example/prepare_data.py`)
|
||||||
|
2) **Train model** (`example/train.py`)
|
||||||
|
3) **Generate samples** (`example/export_samples.py`)
|
||||||
|
4) **Evaluate** (`example/evaluate_generated.py`)
|
||||||
|
5) **Summarize metrics** (`example/summary_metrics.py`)
|
||||||
|
|
||||||
|
1) **数据准备**(统计量与词表)
|
||||||
|
2) **训练模型**
|
||||||
|
3) **生成样本并导出**
|
||||||
|
4) **评估指标**
|
||||||
|
5) **汇总指标**
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. Technical Architecture / 技术架构
|
||||||
|
|
||||||
|
### 4.1 Hybrid Diffusion Model (Core) / 混合扩散模型(核心)
|
||||||
|
Defined in `example/hybrid_diffusion.py`.
|
||||||
|
|
||||||
|
**Key components:**
|
||||||
|
- **Continuous branch**: Gaussian diffusion (DDPM style)
|
||||||
|
- **Discrete branch**: Mask diffusion for categorical tokens
|
||||||
|
- **Shared backbone**: GRU + residual MLP + LayerNorm
|
||||||
|
- **Embedding inputs**:
|
||||||
|
- continuous projection
|
||||||
|
- discrete embeddings per column
|
||||||
|
- time embedding (sinusoidal)
|
||||||
|
- positional embedding (sequence index)
|
||||||
|
- optional condition embedding (`file_id`)
|
||||||
|
|
||||||
|
**Outputs:**
|
||||||
|
- Continuous head: predicts target (`eps`, `x0`, or `v`)
|
||||||
|
- Discrete heads: predict logits for each discrete column
|
||||||
|
|
||||||
|
**核心组成:**
|
||||||
|
- **连续分支:** 高斯扩散(DDPM)
|
||||||
|
- **离散分支:** Mask 扩散
|
||||||
|
- **共享主干:** GRU + 残差 MLP + LayerNorm
|
||||||
|
- **输入嵌入:**
|
||||||
|
- 连续投影
|
||||||
|
- 离散字段嵌入
|
||||||
|
- 时间嵌入(正弦)
|
||||||
|
- 位置嵌入(序列索引)
|
||||||
|
- 条件嵌入(可选,`file_id`)
|
||||||
|
|
||||||
|
**输出:**
|
||||||
|
- 连续 head:预测 `eps/x0/v`
|
||||||
|
- 离散 head:各字段 logits
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 4.2 Feature Graph Mixer (Structure Prior) / 特征图混合器(结构先验)
|
||||||
|
Implemented in `example/hybrid_diffusion.py` as `FeatureGraphMixer`.
|
||||||
|
|
||||||
|
Purpose: inject **learnable feature-dependency prior** without dataset-specific hardcoding.
|
||||||
|
|
||||||
|
**Mechanism:**
|
||||||
|
- Learns a dense feature relation matrix `A`
|
||||||
|
- Applies: `x + x @ A`
|
||||||
|
- Symmetric stabilizing constraint: `(A + A^T)/2`
|
||||||
|
- Controlled by scale and dropout
|
||||||
|
|
||||||
|
**Config:**
|
||||||
|
```
|
||||||
|
"model_use_feature_graph": true,
|
||||||
|
"feature_graph_scale": 0.1,
|
||||||
|
"feature_graph_dropout": 0.0
|
||||||
|
```
|
||||||
|
|
||||||
|
**目的:**在不写死特定数据集关系的情况下,引入**可学习特征依赖先验**。
|
||||||
|
|
||||||
|
**机制:**
|
||||||
|
- 学习稠密关系矩阵 `A`
|
||||||
|
- 特征混合:`x + x @ A`
|
||||||
|
- 对称化稳定:`(A + A^T)/2`
|
||||||
|
- 通过 scale/dropout 控制强度
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 4.3 Two-Stage Temporal Backbone / 两阶段时序骨干
|
||||||
|
Stage-1 uses a **GRU temporal generator** to model sequence trend in normalized space. Stage-2 diffusion then models the **residual** (x − trend). This decouples temporal consistency from distribution alignment.
|
||||||
|
|
||||||
|
第一阶段使用 **GRU 时序生成器**在归一化空间建模序列趋势;第二阶段扩散模型学习**残差**(x − trend),实现时序一致性与分布对齐的解耦。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Diffusion Formulations / 扩散建模形式
|
||||||
|
|
||||||
|
### 5.1 Continuous Diffusion / 连续扩散
|
||||||
|
Forward process:
|
||||||
|
```
|
||||||
|
x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
|
||||||
|
```
|
||||||
|
|
||||||
|
Targets supported:
|
||||||
|
- **eps prediction** (standard DDPM)
|
||||||
|
- **x0 prediction** (direct reconstruction)
|
||||||
|
- **v prediction** (v = sqrt(a_bar)*eps − sqrt(1-a_bar)*x0)
|
||||||
|
|
||||||
|
Current config default:
|
||||||
|
```
|
||||||
|
"cont_target": "v"
|
||||||
|
```
|
||||||
|
|
||||||
|
Sampling uses the target to reconstruct `eps` and apply standard DDPM reverse update.
|
||||||
|
|
||||||
|
**前向扩散:**如上公式。
|
||||||
|
|
||||||
|
**支持的目标:**
|
||||||
|
- `eps`(噪声预测)
|
||||||
|
- `x0`(原样本预测)
|
||||||
|
- `v`(v‑prediction)
|
||||||
|
|
||||||
|
**当前默认:**`cont_target = v`
|
||||||
|
|
||||||
|
**采样:**根据目标反解 `eps` 再执行标准 DDPM 反向步骤。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 5.2 Discrete Diffusion (Mask) / 离散扩散(Mask)
|
||||||
|
Forward process: replace tokens with `[MASK]` using cosine schedule:
|
||||||
|
```
|
||||||
|
p(t) = 0.5 * (1 - cos(pi * t / T))
|
||||||
|
```
|
||||||
|
Optional scale: `disc_mask_scale`
|
||||||
|
|
||||||
|
Reverse process: cross-entropy on masked positions only.
|
||||||
|
|
||||||
|
**前向:**按 cosine schedule 进行 Mask。
|
||||||
|
**反向:**仅在 mask 位置计算交叉熵。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Loss Design (Current) / 当前损失设计
|
||||||
|
Total loss:
|
||||||
|
```
|
||||||
|
L = λ * L_cont + (1 − λ) * L_disc + w_q * L_quantile
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.1 Continuous Loss / 连续损失
|
||||||
|
Depending on `cont_target`:
|
||||||
|
- eps target: MSE(eps_pred, eps)
|
||||||
|
- x0 target: MSE(x0_pred, x0)
|
||||||
|
- v target: MSE(v_pred, v_target)
|
||||||
|
|
||||||
|
Optional inverse-variance weighting:
|
||||||
|
```
|
||||||
|
cont_loss_weighting = "inv_std"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.2 Discrete Loss / 离散损失
|
||||||
|
Cross-entropy on masked positions only.
|
||||||
|
|
||||||
|
### 6.3 Quantile Loss (Distribution Alignment) / 分位数损失(分布对齐)
|
||||||
|
Added to improve KS (distribution shape alignment):
|
||||||
|
- Compute quantiles on generated vs real x0
|
||||||
|
- Loss = Huber or L1 difference on quantiles
|
||||||
|
|
||||||
|
Stabilization:
|
||||||
|
```
|
||||||
|
quantile_loss_warmup_steps
|
||||||
|
quantile_loss_clip
|
||||||
|
quantile_loss_huber_delta
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. Training Strategy / 训练策略
|
||||||
|
Defined in `example/train.py`.
|
||||||
|
|
||||||
|
**Key techniques:**
|
||||||
|
- EMA of model weights
|
||||||
|
- Gradient clipping
|
||||||
|
- Shuffle buffer to reduce batch bias
|
||||||
|
- Optional feature graph prior
|
||||||
|
- Quantile loss warmup for stability
|
||||||
|
- Optional stage-1 temporal GRU (trend) + residual diffusion
|
||||||
|
|
||||||
|
**Config highlights (example/config.json):**
|
||||||
|
```
|
||||||
|
timesteps: 600
|
||||||
|
batch_size: 128
|
||||||
|
seq_len: 128
|
||||||
|
epochs: 10
|
||||||
|
max_batches: 4000
|
||||||
|
lambda: 0.7
|
||||||
|
cont_target: "v"
|
||||||
|
quantile_loss_weight: 0.1
|
||||||
|
model_use_feature_graph: true
|
||||||
|
use_temporal_stage1: true
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. Sampling & Export / 采样与导出
|
||||||
|
Defined in:
|
||||||
|
- `example/sample.py`
|
||||||
|
- `example/export_samples.py`
|
||||||
|
|
||||||
|
**Export steps:**
|
||||||
|
- Reverse diffusion with conditional sampling
|
||||||
|
- Reverse normalize continuous values
|
||||||
|
- Clamp to observed min/max
|
||||||
|
- Restore discrete tokens from vocab
|
||||||
|
- Write to CSV
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. Evaluation Metrics / 评估指标
|
||||||
|
Implemented in `example/evaluate_generated.py`.
|
||||||
|
|
||||||
|
### Continuous Metrics / 连续指标
|
||||||
|
- **KS statistic** (distribution similarity per feature)
|
||||||
|
- **Quantile errors** (q05/q25/q50/q75/q95)
|
||||||
|
- **Lag‑1 correlation diff** (temporal structure)
|
||||||
|
|
||||||
|
### Discrete Metrics / 离散指标
|
||||||
|
- **JSD** over token frequency distribution
|
||||||
|
- **Invalid token counts**
|
||||||
|
|
||||||
|
### Summary Metrics / 汇总指标
|
||||||
|
Auto-logged in:
|
||||||
|
- `example/results/metrics_history.csv`
|
||||||
|
- via `example/summary_metrics.py`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 10. Automation / 自动化
|
||||||
|
|
||||||
|
### One‑click pipeline / 一键流程
|
||||||
|
```
|
||||||
|
python example/run_all.py --device cuda
|
||||||
|
```
|
||||||
|
|
||||||
|
### Metrics logging / 指标记录
|
||||||
|
Each run appends:
|
||||||
|
```
|
||||||
|
timestamp,avg_ks,avg_jsd,avg_lag1_diff
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 11. Key Engineering Decisions / 关键工程决策
|
||||||
|
|
||||||
|
### 11.1 Mixed-Type Diffusion / 混合类型扩散
|
||||||
|
Continuous + discrete handled separately to respect data types.
|
||||||
|
|
||||||
|
### 11.2 Structure Prior / 结构先验
|
||||||
|
Learnable feature graph added to encode implicit dependencies.
|
||||||
|
|
||||||
|
### 11.3 v‑prediction
|
||||||
|
Chosen to stabilize training and improve convergence in diffusion.
|
||||||
|
|
||||||
|
### 11.4 Distribution Alignment / 分布对齐
|
||||||
|
Quantile loss introduced to directly reduce KS.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 12. Known Issues / Current Limitations / 已知问题与当前局限
|
||||||
|
- **KS remains high** in many experiments, meaning continuous distributions are still misaligned.
|
||||||
|
- **Lag‑1 may degrade** when quantile loss is too strong.
|
||||||
|
- **Loss spikes** observed when quantile loss is unstable (mitigated with warmup + clip + Huber).
|
||||||
|
|
||||||
|
**当前问题:**
|
||||||
|
- KS 高,说明连续分布仍未对齐
|
||||||
|
- 分位数损失过强时会损害时序相关性
|
||||||
|
- 分位数损失不稳定时会出现 loss 爆炸(已引入 warmup/clip/Huber)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 13. Suggested Next Steps (Research Roadmap) / 下一步建议(研究路线)
|
||||||
|
1) **SNR-weighted loss** (improve stability across timesteps)
|
||||||
|
2) **Two-stage training** (distribution first, temporal consistency second)
|
||||||
|
3) **Upgrade discrete diffusion** (D3PM-style transitions)
|
||||||
|
4) **Structured conditioning** (state/phase conditioning)
|
||||||
|
5) **Graph-based priors** (explicit feature/plant dependency graphs)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 14. Code Map (Key Files) / 代码索引(关键文件)
|
||||||
|
|
||||||
|
**Core model**
|
||||||
|
- `example/hybrid_diffusion.py`
|
||||||
|
|
||||||
|
**Training**
|
||||||
|
- `example/train.py`
|
||||||
|
|
||||||
|
**Sampling & export**
|
||||||
|
- `example/sample.py`
|
||||||
|
- `example/export_samples.py`
|
||||||
|
|
||||||
|
**Pipeline**
|
||||||
|
- `example/run_all.py`
|
||||||
|
|
||||||
|
**Evaluation**
|
||||||
|
- `example/evaluate_generated.py`
|
||||||
|
- `example/summary_metrics.py`
|
||||||
|
|
||||||
|
**Configs**
|
||||||
|
- `example/config.json`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 15. Summary / 总结
|
||||||
|
This project implements a **hybrid diffusion model for ICS traffic features**, combining continuous Gaussian diffusion with discrete mask diffusion, enhanced with a **learnable feature-graph prior**. The system includes a full pipeline for preparation, training, sampling, exporting, and evaluation. Key research challenges remain in **distribution alignment (KS)** and **joint optimization of distribution fidelity vs temporal consistency**, motivating future improvements such as SNR-weighted loss, staged training, and stronger structural priors.
|
||||||
|
|
||||||
|
本项目实现了用于 ICS 流量特征的**混合扩散模型**,将连续高斯扩散与离散 Mask 扩散结合,并引入**可学习特征图先验**。系统包含完整的数据准备、训练、采样、导出与评估流程。当前研究挑战集中在**连续分布对齐(KS)**与**分布/时序一致性之间的权衡**,后续可通过 SNR‑weighted loss、分阶段训练与更强结构先验继续改进。
|
||||||
Reference in New Issue
Block a user