update新结构

This commit is contained in:
2026-01-26 19:00:16 +08:00
parent f8edee9510
commit dd4c1e171f
9 changed files with 545 additions and 5 deletions

View File

@@ -73,3 +73,4 @@ python example/run_pipeline.py --device auto
- `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it. - `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
- `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU. - `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU.
- Optional feature-graph mixer (`model_use_feature_graph`) adds a learnable relation prior across feature channels. - Optional feature-graph mixer (`model_use_feature_graph`) adds a learnable relation prior across feature channels.
- Optional two-stage temporal backbone (`use_temporal_stage1`) learns a GRU-based sequence trend; diffusion models the residual.

View File

@@ -35,6 +35,11 @@
"model_use_feature_graph": true, "model_use_feature_graph": true,
"feature_graph_scale": 0.1, "feature_graph_scale": 0.1,
"feature_graph_dropout": 0.0, "feature_graph_dropout": 0.0,
"use_temporal_stage1": true,
"temporal_hidden_dim": 256,
"temporal_num_layers": 1,
"temporal_dropout": 0.0,
"temporal_loss_weight": 1.0,
"disc_mask_scale": 0.9, "disc_mask_scale": 0.9,
"cont_loss_weighting": "inv_std", "cont_loss_weighting": "inv_std",
"cont_loss_eps": 1e-6, "cont_loss_eps": 1e-6,

View File

@@ -13,7 +13,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from data_utils import load_split from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule
from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path
@@ -140,6 +140,10 @@ def main():
raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob) raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob)
cont_target = str(cfg.get("cont_target", "eps")) cont_target = str(cfg.get("cont_target", "eps"))
cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0)) cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
model = HybridDiffusionModel( model = HybridDiffusionModel(
cont_dim=len(cont_cols), cont_dim=len(cont_cols),
@@ -166,6 +170,20 @@ def main():
model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True)) model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
model.eval() model.eval()
temporal_model = None
if use_temporal_stage1:
temporal_model = TemporalGRUGenerator(
input_dim=len(cont_cols),
hidden_dim=temporal_hidden_dim,
num_layers=temporal_num_layers,
dropout=temporal_dropout,
).to(device)
temporal_path = Path(args.model_path).with_name("temporal.pt")
if not temporal_path.exists():
raise SystemExit(f"missing temporal model file: {temporal_path}")
temporal_model.load_state_dict(torch.load(temporal_path, map_location=device, weights_only=True))
temporal_model.eval()
betas = cosine_beta_schedule(args.timesteps).to(device) betas = cosine_beta_schedule(args.timesteps).to(device)
alphas = 1.0 - betas alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_cumprod = torch.cumprod(alphas, dim=0)
@@ -192,6 +210,10 @@ def main():
cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long) cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
cond = cond_id cond = cond_id
trend = None
if temporal_model is not None:
trend = temporal_model.generate(args.batch_size, args.seq_len, device)
for t in reversed(range(args.timesteps)): for t in reversed(range(args.timesteps)):
t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long) t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
eps_pred, logits = model(x_cont, x_disc, t_batch, cond) eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
@@ -232,6 +254,8 @@ def main():
) )
x_disc[:, :, i][mask] = sampled[mask] x_disc[:, :, i][mask] = sampled[mask]
if trend is not None:
x_cont = x_cont + trend
# move to CPU for export # move to CPU for export
x_cont = x_cont.cpu() x_cont = x_cont.cpu()
x_disc = x_disc.cpu() x_disc = x_disc.cpu()

View File

@@ -84,6 +84,46 @@ class FeatureGraphMixer(nn.Module):
return x + mixed return x + mixed
class TemporalGRUGenerator(nn.Module):
"""Stage-1 temporal generator (autoregressive GRU) for sequence backbone."""
def __init__(self, input_dim: int, hidden_dim: int = 256, num_layers: int = 1, dropout: float = 0.0):
super().__init__()
self.start_token = nn.Parameter(torch.zeros(input_dim))
self.gru = nn.GRU(
input_dim,
hidden_dim,
num_layers=num_layers,
dropout=dropout if num_layers > 1 else 0.0,
batch_first=True,
)
self.out = nn.Linear(hidden_dim, input_dim)
def forward_teacher(self, x: torch.Tensor) -> torch.Tensor:
"""Teacher-forced next-step prediction. Returns trend sequence and preds."""
if x.size(1) < 2:
raise ValueError("sequence length must be >= 2 for teacher forcing")
inp = x[:, :-1, :]
out, _ = self.gru(inp)
pred_next = self.out(out)
trend = torch.zeros_like(x)
trend[:, 0, :] = x[:, 0, :]
trend[:, 1:, :] = pred_next
return trend, pred_next
def generate(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
"""Autoregressively generate a backbone sequence."""
h = None
prev = self.start_token.unsqueeze(0).expand(batch_size, -1).to(device)
outputs = []
for _ in range(seq_len):
out, h = self.gru(prev.unsqueeze(1), h)
nxt = self.out(out.squeeze(1))
outputs.append(nxt.unsqueeze(1))
prev = nxt
return torch.cat(outputs, dim=1)
class HybridDiffusionModel(nn.Module): class HybridDiffusionModel(nn.Module):
def __init__( def __init__(
self, self,

View File

@@ -75,6 +75,7 @@ def main():
run([sys.executable, str(base_dir / "evaluate_generated.py"), "--reference", str(ref)]) run([sys.executable, str(base_dir / "evaluate_generated.py"), "--reference", str(ref)])
else: else:
run([sys.executable, str(base_dir / "evaluate_generated.py")]) run([sys.executable, str(base_dir / "evaluate_generated.py")])
run([sys.executable, str(base_dir / "summary_metrics.py")])
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -11,6 +11,7 @@ import torch.nn.functional as F
from data_utils import load_split from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
from hybrid_diffusion import TemporalGRUGenerator
from platform_utils import resolve_device, safe_path, ensure_dir from platform_utils import resolve_device, safe_path, ensure_dir
BASE_DIR = Path(__file__).resolve().parent BASE_DIR = Path(__file__).resolve().parent
@@ -59,6 +60,10 @@ def main():
model_use_feature_graph = bool(cfg.get("model_use_feature_graph", False)) model_use_feature_graph = bool(cfg.get("model_use_feature_graph", False))
feature_graph_scale = float(cfg.get("feature_graph_scale", 0.1)) feature_graph_scale = float(cfg.get("feature_graph_scale", 0.1))
feature_graph_dropout = float(cfg.get("feature_graph_dropout", 0.0)) feature_graph_dropout = float(cfg.get("feature_graph_dropout", 0.0))
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
split = load_split(str(SPLIT_PATH)) split = load_split(str(SPLIT_PATH))
time_col = split.get("time_column", "time") time_col = split.get("time_column", "time")
@@ -98,6 +103,20 @@ def main():
model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True)) model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
model.eval() model.eval()
temporal_model = None
if use_temporal_stage1:
temporal_model = TemporalGRUGenerator(
input_dim=len(cont_cols),
hidden_dim=temporal_hidden_dim,
num_layers=temporal_num_layers,
dropout=temporal_dropout,
).to(DEVICE)
temporal_path = BASE_DIR / "results" / "temporal.pt"
if not temporal_path.exists():
raise SystemExit(f"missing temporal model file: {temporal_path}")
temporal_model.load_state_dict(torch.load(str(temporal_path), map_location=DEVICE, weights_only=True))
temporal_model.eval()
betas = cosine_beta_schedule(timesteps).to(DEVICE) betas = cosine_beta_schedule(timesteps).to(DEVICE)
alphas = 1.0 - betas alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_cumprod = torch.cumprod(alphas, dim=0)
@@ -116,6 +135,10 @@ def main():
raise SystemExit("use_condition enabled but no files matched data_glob") raise SystemExit("use_condition enabled but no files matched data_glob")
cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long) cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long)
trend = None
if temporal_model is not None:
trend = temporal_model.generate(batch_size, seq_len, DEVICE)
for t in reversed(range(timesteps)): for t in reversed(range(timesteps)):
t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long) t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
eps_pred, logits = model(x_cont, x_disc, t_batch, cond) eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
@@ -155,6 +178,8 @@ def main():
sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(BATCH_SIZE, SEQ_LEN) sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(BATCH_SIZE, SEQ_LEN)
x_disc[:, :, i][mask] = sampled[mask] x_disc[:, :, i][mask] = sampled[mask]
if trend is not None:
x_cont = x_cont + trend
print("sampled_cont_shape", tuple(x_cont.shape)) print("sampled_cont_shape", tuple(x_cont.shape))
print("sampled_disc_shape", tuple(x_disc.shape)) print("sampled_disc_shape", tuple(x_disc.shape))

View File

@@ -0,0 +1,40 @@
#!/usr/bin/env python3
"""Print average metrics from eval.json for quick tracking."""
import json
from datetime import datetime
from pathlib import Path
def mean(values):
return sum(values) / len(values) if values else None
def main():
base_dir = Path(__file__).resolve().parent
eval_path = base_dir / "results" / "eval.json"
if not eval_path.exists():
raise SystemExit(f"missing eval.json: {eval_path}")
obj = json.loads(eval_path.read_text(encoding="utf-8"))
ks = list(obj.get("continuous_ks", {}).values())
jsd = list(obj.get("discrete_jsd", {}).values())
lag = list(obj.get("continuous_lag1_diff", {}).values())
avg_ks = mean(ks)
avg_jsd = mean(jsd)
avg_lag1 = mean(lag)
print("avg_ks", avg_ks)
print("avg_jsd", avg_jsd)
print("avg_lag1_diff", avg_lag1)
history_path = base_dir / "results" / "metrics_history.csv"
if not history_path.exists():
history_path.write_text("timestamp,avg_ks,avg_jsd,avg_lag1_diff\n", encoding="utf-8")
with history_path.open("a", encoding="utf-8") as f:
f.write(f"{datetime.utcnow().isoformat()},{avg_ks},{avg_jsd},{avg_lag1}\n")
if __name__ == "__main__":
main()

View File

@@ -14,6 +14,7 @@ import torch.nn.functional as F
from data_utils import load_split, windowed_batches from data_utils import load_split, windowed_batches
from hybrid_diffusion import ( from hybrid_diffusion import (
HybridDiffusionModel, HybridDiffusionModel,
TemporalGRUGenerator,
cosine_beta_schedule, cosine_beta_schedule,
q_sample_continuous, q_sample_continuous,
q_sample_discrete, q_sample_discrete,
@@ -61,6 +62,11 @@ DEFAULTS = {
"model_use_feature_graph": True, "model_use_feature_graph": True,
"feature_graph_scale": 0.1, "feature_graph_scale": 0.1,
"feature_graph_dropout": 0.0, "feature_graph_dropout": 0.0,
"use_temporal_stage1": True,
"temporal_hidden_dim": 256,
"temporal_num_layers": 1,
"temporal_dropout": 0.0,
"temporal_loss_weight": 1.0,
"disc_mask_scale": 0.9, "disc_mask_scale": 0.9,
"shuffle_buffer": 256, "shuffle_buffer": 256,
"cont_loss_weighting": "none", # none | inv_std "cont_loss_weighting": "none", # none | inv_std
@@ -204,7 +210,19 @@ def main():
use_tanh_eps=bool(config.get("use_tanh_eps", False)), use_tanh_eps=bool(config.get("use_tanh_eps", False)),
eps_scale=float(config.get("eps_scale", 1.0)), eps_scale=float(config.get("eps_scale", 1.0)),
).to(device) ).to(device)
temporal_model = None
if bool(config.get("use_temporal_stage1", False)):
temporal_model = TemporalGRUGenerator(
input_dim=len(cont_cols),
hidden_dim=int(config.get("temporal_hidden_dim", 256)),
num_layers=int(config.get("temporal_num_layers", 1)),
dropout=float(config.get("temporal_dropout", 0.0)),
).to(device)
opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"])) opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
if temporal_model is not None:
opt_temporal = torch.optim.Adam(temporal_model.parameters(), lr=float(config["lr"]))
else:
opt_temporal = None
ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None
betas = cosine_beta_schedule(int(config["timesteps"])).to(device) betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
@@ -250,10 +268,20 @@ def main():
x_cont = x_cont.to(device) x_cont = x_cont.to(device)
x_disc = x_disc.to(device) x_disc = x_disc.to(device)
temporal_loss = None
x_cont_resid = x_cont
trend = None
if temporal_model is not None:
trend, pred_next = temporal_model.forward_teacher(x_cont)
temporal_loss = F.mse_loss(pred_next, x_cont[:, 1:, :])
temporal_loss = temporal_loss * float(config.get("temporal_loss_weight", 1.0))
trend = trend.detach()
x_cont_resid = x_cont - trend
bsz = x_cont.size(0) bsz = x_cont.size(0)
t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device) t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod) x_cont_t, noise = q_sample_continuous(x_cont_resid, t, alphas_cumprod)
mask_tokens = torch.tensor(vocab_sizes, device=device) mask_tokens = torch.tensor(vocab_sizes, device=device)
x_disc_t, mask = q_sample_discrete( x_disc_t, mask = q_sample_discrete(
@@ -268,13 +296,13 @@ def main():
cont_target = str(config.get("cont_target", "eps")) cont_target = str(config.get("cont_target", "eps"))
if cont_target == "x0": if cont_target == "x0":
x0_target = x_cont x0_target = x_cont_resid
if float(config.get("cont_clamp_x0", 0.0)) > 0: if float(config.get("cont_clamp_x0", 0.0)) > 0:
x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"])) x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"]))
loss_base = (eps_pred - x0_target) ** 2 loss_base = (eps_pred - x0_target) ** 2
elif cont_target == "v": elif cont_target == "v":
a_bar_t = alphas_cumprod[t].view(-1, 1, 1) a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
v_target = torch.sqrt(a_bar_t) * noise - torch.sqrt(1.0 - a_bar_t) * x_cont v_target = torch.sqrt(a_bar_t) * noise - torch.sqrt(1.0 - a_bar_t) * x_cont_resid
loss_base = (eps_pred - v_target) ** 2 loss_base = (eps_pred - v_target) ** 2
else: else:
loss_base = (eps_pred - noise) ** 2 loss_base = (eps_pred - noise) ** 2
@@ -311,7 +339,7 @@ def main():
q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95]) q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype) q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
# Use normalized space for stable quantiles on x0. # Use normalized space for stable quantiles on x0.
x_real = x_cont x_real = x_cont_resid
a_bar_t = alphas_cumprod[t].view(-1, 1, 1) a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
if cont_target == "x0": if cont_target == "x0":
x_gen = eps_pred x_gen = eps_pred
@@ -336,11 +364,18 @@ def main():
else: else:
quantile_loss = torch.mean(torch.abs(q_diff)) quantile_loss = torch.mean(torch.abs(q_diff))
loss = loss + q_weight * quantile_loss loss = loss + q_weight * quantile_loss
opt.zero_grad() opt.zero_grad()
loss.backward() loss.backward()
if float(config.get("grad_clip", 0.0)) > 0: if float(config.get("grad_clip", 0.0)) > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"])) torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
opt.step() opt.step()
if opt_temporal is not None:
opt_temporal.zero_grad()
temporal_loss.backward()
if float(config.get("grad_clip", 0.0)) > 0:
torch.nn.utils.clip_grad_norm_(temporal_model.parameters(), float(config["grad_clip"]))
opt_temporal.step()
if ema is not None: if ema is not None:
ema.update(model) ema.update(model)
@@ -375,11 +410,15 @@ def main():
} }
if ema is not None: if ema is not None:
ckpt["ema"] = ema.state_dict() ckpt["ema"] = ema.state_dict()
if temporal_model is not None:
ckpt["temporal"] = temporal_model.state_dict()
torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt")) torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))
torch.save(model.state_dict(), os.path.join(out_dir, "model.pt")) torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
if ema is not None: if ema is not None:
torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt")) torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))
if temporal_model is not None:
torch.save(temporal_model.state_dict(), os.path.join(out_dir, "temporal.pt"))
if __name__ == "__main__": if __name__ == "__main__":

365
report.md Normal file
View File

@@ -0,0 +1,365 @@
# Hybrid Diffusion for ICS Traffic (HAI 21.03) — Project Report
# 工业控制系统流量混合扩散生成HAI 21.03)— 项目报告
## 1. Project Goal / 项目目标
Build a **hybrid diffusion-based generator** for industrial control system (ICS) traffic features, targeting **mixed continuous + discrete** feature sequences. The output is **feature-level sequences**, not raw packets. The generator should preserve:
- **Distributional fidelity** (continuous value ranges and discrete frequencies)
- **Temporal consistency** (time correlation and sequence structure)
- **Protocol/field consistency** (for discrete fields)
构建一个用于工业控制系统ICS流量特征的**混合扩散生成模型**,面向**连续+离散混合特征序列**。输出为**特征级序列**而非原始报文。生成结果需要同时保持:
- **分布一致性**(连续值范围与离散取值频率)
- **时序一致性**(时间相关性与序列结构)
- **字段/协议一致性**(离散字段的逻辑一致)
This project is aligned with the STOUTER idea of **structure-aware diffusion** for spatiotemporal data, but applied to **ICS feature sequences** rather than cellular traffic.
本项目呼应 STOUTER 的**结构先验+扩散**思想,但应用于**ICS 特征序列**而非蜂窝流量。
---
## 2. Data and Scope / 数据与范围
**Dataset used in the current implementation:** HAI 21.03 (CSV feature traces)
**当前实现使用的数据集:** HAI 21.03CSV 特征序列)
**Data location (default in config):**
- `dataset/hai/hai-21.03/train*.csv.gz`
**数据位置config 默认):**
- `dataset/hai/hai-21.03/train*.csv.gz`
**Feature split (fixed schema):**
- Defined in `example/feature_split.json`
- **Continuous features:** sensor/process values
- **Discrete features:** binary/low-cardinality status/flag fields
- `time` column is excluded from modeling
**特征拆分(固定 schema**
- `example/feature_split.json`
- **连续特征:** 传感器/过程值
- **离散特征:** 二值/低基数状态字段
- `time` 列不参与训练
---
## 3. End-to-End Pipeline / 端到端流程
**One command pipeline:**
```
python example/run_all.py --device cuda
```
**一键流程:**
```
python example/run_all.py --device cuda
```
### Pipeline stages / 流程阶段
1) **Prepare data** (`example/prepare_data.py`)
2) **Train model** (`example/train.py`)
3) **Generate samples** (`example/export_samples.py`)
4) **Evaluate** (`example/evaluate_generated.py`)
5) **Summarize metrics** (`example/summary_metrics.py`)
1) **数据准备**(统计量与词表)
2) **训练模型**
3) **生成样本并导出**
4) **评估指标**
5) **汇总指标**
---
## 4. Technical Architecture / 技术架构
### 4.1 Hybrid Diffusion Model (Core) / 混合扩散模型(核心)
Defined in `example/hybrid_diffusion.py`.
**Key components:**
- **Continuous branch**: Gaussian diffusion (DDPM style)
- **Discrete branch**: Mask diffusion for categorical tokens
- **Shared backbone**: GRU + residual MLP + LayerNorm
- **Embedding inputs**:
- continuous projection
- discrete embeddings per column
- time embedding (sinusoidal)
- positional embedding (sequence index)
- optional condition embedding (`file_id`)
**Outputs:**
- Continuous head: predicts target (`eps`, `x0`, or `v`)
- Discrete heads: predict logits for each discrete column
**核心组成:**
- **连续分支:** 高斯扩散DDPM
- **离散分支:** Mask 扩散
- **共享主干:** GRU + 残差 MLP + LayerNorm
- **输入嵌入:**
- 连续投影
- 离散字段嵌入
- 时间嵌入(正弦)
- 位置嵌入(序列索引)
- 条件嵌入(可选,`file_id`
**输出:**
- 连续 head预测 `eps/x0/v`
- 离散 head各字段 logits
---
### 4.2 Feature Graph Mixer (Structure Prior) / 特征图混合器(结构先验)
Implemented in `example/hybrid_diffusion.py` as `FeatureGraphMixer`.
Purpose: inject **learnable feature-dependency prior** without dataset-specific hardcoding.
**Mechanism:**
- Learns a dense feature relation matrix `A`
- Applies: `x + x @ A`
- Symmetric stabilizing constraint: `(A + A^T)/2`
- Controlled by scale and dropout
**Config:**
```
"model_use_feature_graph": true,
"feature_graph_scale": 0.1,
"feature_graph_dropout": 0.0
```
**目的:**在不写死特定数据集关系的情况下,引入**可学习特征依赖先验**
**机制:**
- 学习稠密关系矩阵 `A`
- 特征混合:`x + x @ A`
- 对称化稳定:`(A + A^T)/2`
- 通过 scale/dropout 控制强度
---
### 4.3 Two-Stage Temporal Backbone / 两阶段时序骨干
Stage-1 uses a **GRU temporal generator** to model sequence trend in normalized space. Stage-2 diffusion then models the **residual** (x trend). This decouples temporal consistency from distribution alignment.
第一阶段使用 **GRU 时序生成器**在归一化空间建模序列趋势;第二阶段扩散模型学习**残差**x trend实现时序一致性与分布对齐的解耦。
---
## 5. Diffusion Formulations / 扩散建模形式
### 5.1 Continuous Diffusion / 连续扩散
Forward process:
```
x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
```
Targets supported:
- **eps prediction** (standard DDPM)
- **x0 prediction** (direct reconstruction)
- **v prediction** (v = sqrt(a_bar)*eps sqrt(1-a_bar)*x0)
Current config default:
```
"cont_target": "v"
```
Sampling uses the target to reconstruct `eps` and apply standard DDPM reverse update.
**前向扩散:**如上公式。
**支持的目标:**
- `eps`(噪声预测)
- `x0`(原样本预测)
- `v`vprediction
**当前默认:**`cont_target = v`
**采样:**根据目标反解 `eps` 再执行标准 DDPM 反向步骤。
---
### 5.2 Discrete Diffusion (Mask) / 离散扩散Mask
Forward process: replace tokens with `[MASK]` using cosine schedule:
```
p(t) = 0.5 * (1 - cos(pi * t / T))
```
Optional scale: `disc_mask_scale`
Reverse process: cross-entropy on masked positions only.
**前向:**按 cosine schedule 进行 Mask。
**反向:**仅在 mask 位置计算交叉熵。
---
## 6. Loss Design (Current) / 当前损失设计
Total loss:
```
L = λ * L_cont + (1 λ) * L_disc + w_q * L_quantile
```
### 6.1 Continuous Loss / 连续损失
Depending on `cont_target`:
- eps target: MSE(eps_pred, eps)
- x0 target: MSE(x0_pred, x0)
- v target: MSE(v_pred, v_target)
Optional inverse-variance weighting:
```
cont_loss_weighting = "inv_std"
```
### 6.2 Discrete Loss / 离散损失
Cross-entropy on masked positions only.
### 6.3 Quantile Loss (Distribution Alignment) / 分位数损失(分布对齐)
Added to improve KS (distribution shape alignment):
- Compute quantiles on generated vs real x0
- Loss = Huber or L1 difference on quantiles
Stabilization:
```
quantile_loss_warmup_steps
quantile_loss_clip
quantile_loss_huber_delta
```
---
## 7. Training Strategy / 训练策略
Defined in `example/train.py`.
**Key techniques:**
- EMA of model weights
- Gradient clipping
- Shuffle buffer to reduce batch bias
- Optional feature graph prior
- Quantile loss warmup for stability
- Optional stage-1 temporal GRU (trend) + residual diffusion
**Config highlights (example/config.json):**
```
timesteps: 600
batch_size: 128
seq_len: 128
epochs: 10
max_batches: 4000
lambda: 0.7
cont_target: "v"
quantile_loss_weight: 0.1
model_use_feature_graph: true
use_temporal_stage1: true
```
---
## 8. Sampling & Export / 采样与导出
Defined in:
- `example/sample.py`
- `example/export_samples.py`
**Export steps:**
- Reverse diffusion with conditional sampling
- Reverse normalize continuous values
- Clamp to observed min/max
- Restore discrete tokens from vocab
- Write to CSV
---
## 9. Evaluation Metrics / 评估指标
Implemented in `example/evaluate_generated.py`.
### Continuous Metrics / 连续指标
- **KS statistic** (distribution similarity per feature)
- **Quantile errors** (q05/q25/q50/q75/q95)
- **Lag1 correlation diff** (temporal structure)
### Discrete Metrics / 离散指标
- **JSD** over token frequency distribution
- **Invalid token counts**
### Summary Metrics / 汇总指标
Auto-logged in:
- `example/results/metrics_history.csv`
- via `example/summary_metrics.py`
---
## 10. Automation / 自动化
### Oneclick pipeline / 一键流程
```
python example/run_all.py --device cuda
```
### Metrics logging / 指标记录
Each run appends:
```
timestamp,avg_ks,avg_jsd,avg_lag1_diff
```
---
## 11. Key Engineering Decisions / 关键工程决策
### 11.1 Mixed-Type Diffusion / 混合类型扩散
Continuous + discrete handled separately to respect data types.
### 11.2 Structure Prior / 结构先验
Learnable feature graph added to encode implicit dependencies.
### 11.3 vprediction
Chosen to stabilize training and improve convergence in diffusion.
### 11.4 Distribution Alignment / 分布对齐
Quantile loss introduced to directly reduce KS.
---
## 12. Known Issues / Current Limitations / 已知问题与当前局限
- **KS remains high** in many experiments, meaning continuous distributions are still misaligned.
- **Lag1 may degrade** when quantile loss is too strong.
- **Loss spikes** observed when quantile loss is unstable (mitigated with warmup + clip + Huber).
**当前问题:**
- KS 高,说明连续分布仍未对齐
- 分位数损失过强时会损害时序相关性
- 分位数损失不稳定时会出现 loss 爆炸(已引入 warmup/clip/Huber
---
## 13. Suggested Next Steps (Research Roadmap) / 下一步建议(研究路线)
1) **SNR-weighted loss** (improve stability across timesteps)
2) **Two-stage training** (distribution first, temporal consistency second)
3) **Upgrade discrete diffusion** (D3PM-style transitions)
4) **Structured conditioning** (state/phase conditioning)
5) **Graph-based priors** (explicit feature/plant dependency graphs)
---
## 14. Code Map (Key Files) / 代码索引(关键文件)
**Core model**
- `example/hybrid_diffusion.py`
**Training**
- `example/train.py`
**Sampling & export**
- `example/sample.py`
- `example/export_samples.py`
**Pipeline**
- `example/run_all.py`
**Evaluation**
- `example/evaluate_generated.py`
- `example/summary_metrics.py`
**Configs**
- `example/config.json`
---
## 15. Summary / 总结
This project implements a **hybrid diffusion model for ICS traffic features**, combining continuous Gaussian diffusion with discrete mask diffusion, enhanced with a **learnable feature-graph prior**. The system includes a full pipeline for preparation, training, sampling, exporting, and evaluation. Key research challenges remain in **distribution alignment (KS)** and **joint optimization of distribution fidelity vs temporal consistency**, motivating future improvements such as SNR-weighted loss, staged training, and stronger structural priors.
本项目实现了用于 ICS 流量特征的**混合扩散模型**,将连续高斯扩散与离散 Mask 扩散结合,并引入**可学习特征图先验**。系统包含完整的数据准备、训练、采样、导出与评估流程。当前研究挑战集中在**连续分布对齐KS**与**分布/时序一致性之间的权衡**,后续可通过 SNRweighted loss、分阶段训练与更强结构先验继续改进。