From dd4c1e171f9d45ae3d8a7c4b61fc25336bb8856c Mon Sep 17 00:00:00 2001 From: MingzheYang Date: Mon, 26 Jan 2026 19:00:16 +0800 Subject: [PATCH] =?UTF-8?q?update=E6=96=B0=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- example/README.md | 1 + example/config.json | 5 + example/export_samples.py | 26 ++- example/hybrid_diffusion.py | 40 ++++ example/run_all.py | 1 + example/sample.py | 25 +++ example/summary_metrics.py | 40 ++++ example/train.py | 47 ++++- report.md | 365 ++++++++++++++++++++++++++++++++++++ 9 files changed, 545 insertions(+), 5 deletions(-) create mode 100644 example/summary_metrics.py create mode 100644 report.md diff --git a/example/README.md b/example/README.md index 04d2cb4..33e795b 100644 --- a/example/README.md +++ b/example/README.md @@ -73,3 +73,4 @@ python example/run_pipeline.py --device auto - `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it. - `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU. - Optional feature-graph mixer (`model_use_feature_graph`) adds a learnable relation prior across feature channels. +- Optional two-stage temporal backbone (`use_temporal_stage1`) learns a GRU-based sequence trend; diffusion models the residual. diff --git a/example/config.json b/example/config.json index 8e81f00..daa574d 100644 --- a/example/config.json +++ b/example/config.json @@ -35,6 +35,11 @@ "model_use_feature_graph": true, "feature_graph_scale": 0.1, "feature_graph_dropout": 0.0, + "use_temporal_stage1": true, + "temporal_hidden_dim": 256, + "temporal_num_layers": 1, + "temporal_dropout": 0.0, + "temporal_loss_weight": 1.0, "disc_mask_scale": 0.9, "cont_loss_weighting": "inv_std", "cont_loss_eps": 1e-6, diff --git a/example/export_samples.py b/example/export_samples.py index 0809bd9..7c01704 100644 --- a/example/export_samples.py +++ b/example/export_samples.py @@ -13,7 +13,7 @@ import torch import torch.nn.functional as F from data_utils import load_split -from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule +from hybrid_diffusion import HybridDiffusionModel, TemporalGRUGenerator, cosine_beta_schedule from platform_utils import resolve_device, safe_path, ensure_dir, resolve_path @@ -140,6 +140,10 @@ def main(): raise SystemExit("use_condition enabled but no files matched data_glob: %s" % cfg_glob) cont_target = str(cfg.get("cont_target", "eps")) cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0)) + use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False)) + temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256)) + temporal_num_layers = int(cfg.get("temporal_num_layers", 1)) + temporal_dropout = float(cfg.get("temporal_dropout", 0.0)) model = HybridDiffusionModel( cont_dim=len(cont_cols), @@ -166,6 +170,20 @@ def main(): model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True)) model.eval() + temporal_model = None + if use_temporal_stage1: + temporal_model = TemporalGRUGenerator( + input_dim=len(cont_cols), + hidden_dim=temporal_hidden_dim, + num_layers=temporal_num_layers, + dropout=temporal_dropout, + ).to(device) + temporal_path = Path(args.model_path).with_name("temporal.pt") + if not temporal_path.exists(): + raise SystemExit(f"missing temporal model file: {temporal_path}") + temporal_model.load_state_dict(torch.load(temporal_path, map_location=device, weights_only=True)) + temporal_model.eval() + betas = cosine_beta_schedule(args.timesteps).to(device) alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) @@ -192,6 +210,10 @@ def main(): cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long) cond = cond_id + trend = None + if temporal_model is not None: + trend = temporal_model.generate(args.batch_size, args.seq_len, device) + for t in reversed(range(args.timesteps)): t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long) eps_pred, logits = model(x_cont, x_disc, t_batch, cond) @@ -232,6 +254,8 @@ def main(): ) x_disc[:, :, i][mask] = sampled[mask] + if trend is not None: + x_cont = x_cont + trend # move to CPU for export x_cont = x_cont.cpu() x_disc = x_disc.cpu() diff --git a/example/hybrid_diffusion.py b/example/hybrid_diffusion.py index 4245d91..4482bb6 100755 --- a/example/hybrid_diffusion.py +++ b/example/hybrid_diffusion.py @@ -84,6 +84,46 @@ class FeatureGraphMixer(nn.Module): return x + mixed +class TemporalGRUGenerator(nn.Module): + """Stage-1 temporal generator (autoregressive GRU) for sequence backbone.""" + + def __init__(self, input_dim: int, hidden_dim: int = 256, num_layers: int = 1, dropout: float = 0.0): + super().__init__() + self.start_token = nn.Parameter(torch.zeros(input_dim)) + self.gru = nn.GRU( + input_dim, + hidden_dim, + num_layers=num_layers, + dropout=dropout if num_layers > 1 else 0.0, + batch_first=True, + ) + self.out = nn.Linear(hidden_dim, input_dim) + + def forward_teacher(self, x: torch.Tensor) -> torch.Tensor: + """Teacher-forced next-step prediction. Returns trend sequence and preds.""" + if x.size(1) < 2: + raise ValueError("sequence length must be >= 2 for teacher forcing") + inp = x[:, :-1, :] + out, _ = self.gru(inp) + pred_next = self.out(out) + trend = torch.zeros_like(x) + trend[:, 0, :] = x[:, 0, :] + trend[:, 1:, :] = pred_next + return trend, pred_next + + def generate(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor: + """Autoregressively generate a backbone sequence.""" + h = None + prev = self.start_token.unsqueeze(0).expand(batch_size, -1).to(device) + outputs = [] + for _ in range(seq_len): + out, h = self.gru(prev.unsqueeze(1), h) + nxt = self.out(out.squeeze(1)) + outputs.append(nxt.unsqueeze(1)) + prev = nxt + return torch.cat(outputs, dim=1) + + class HybridDiffusionModel(nn.Module): def __init__( self, diff --git a/example/run_all.py b/example/run_all.py index 7926852..e5dbbcd 100644 --- a/example/run_all.py +++ b/example/run_all.py @@ -75,6 +75,7 @@ def main(): run([sys.executable, str(base_dir / "evaluate_generated.py"), "--reference", str(ref)]) else: run([sys.executable, str(base_dir / "evaluate_generated.py")]) + run([sys.executable, str(base_dir / "summary_metrics.py")]) if __name__ == "__main__": diff --git a/example/sample.py b/example/sample.py index 14cda54..f6900b3 100755 --- a/example/sample.py +++ b/example/sample.py @@ -11,6 +11,7 @@ import torch.nn.functional as F from data_utils import load_split from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule +from hybrid_diffusion import TemporalGRUGenerator from platform_utils import resolve_device, safe_path, ensure_dir BASE_DIR = Path(__file__).resolve().parent @@ -59,6 +60,10 @@ def main(): model_use_feature_graph = bool(cfg.get("model_use_feature_graph", False)) feature_graph_scale = float(cfg.get("feature_graph_scale", 0.1)) feature_graph_dropout = float(cfg.get("feature_graph_dropout", 0.0)) + use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False)) + temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256)) + temporal_num_layers = int(cfg.get("temporal_num_layers", 1)) + temporal_dropout = float(cfg.get("temporal_dropout", 0.0)) split = load_split(str(SPLIT_PATH)) time_col = split.get("time_column", "time") @@ -98,6 +103,20 @@ def main(): model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True)) model.eval() + temporal_model = None + if use_temporal_stage1: + temporal_model = TemporalGRUGenerator( + input_dim=len(cont_cols), + hidden_dim=temporal_hidden_dim, + num_layers=temporal_num_layers, + dropout=temporal_dropout, + ).to(DEVICE) + temporal_path = BASE_DIR / "results" / "temporal.pt" + if not temporal_path.exists(): + raise SystemExit(f"missing temporal model file: {temporal_path}") + temporal_model.load_state_dict(torch.load(str(temporal_path), map_location=DEVICE, weights_only=True)) + temporal_model.eval() + betas = cosine_beta_schedule(timesteps).to(DEVICE) alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) @@ -116,6 +135,10 @@ def main(): raise SystemExit("use_condition enabled but no files matched data_glob") cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long) + trend = None + if temporal_model is not None: + trend = temporal_model.generate(batch_size, seq_len, DEVICE) + for t in reversed(range(timesteps)): t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long) eps_pred, logits = model(x_cont, x_disc, t_batch, cond) @@ -155,6 +178,8 @@ def main(): sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(BATCH_SIZE, SEQ_LEN) x_disc[:, :, i][mask] = sampled[mask] + if trend is not None: + x_cont = x_cont + trend print("sampled_cont_shape", tuple(x_cont.shape)) print("sampled_disc_shape", tuple(x_disc.shape)) diff --git a/example/summary_metrics.py b/example/summary_metrics.py new file mode 100644 index 0000000..e09c804 --- /dev/null +++ b/example/summary_metrics.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +"""Print average metrics from eval.json for quick tracking.""" + +import json +from datetime import datetime +from pathlib import Path + + +def mean(values): + return sum(values) / len(values) if values else None + + +def main(): + base_dir = Path(__file__).resolve().parent + eval_path = base_dir / "results" / "eval.json" + if not eval_path.exists(): + raise SystemExit(f"missing eval.json: {eval_path}") + + obj = json.loads(eval_path.read_text(encoding="utf-8")) + ks = list(obj.get("continuous_ks", {}).values()) + jsd = list(obj.get("discrete_jsd", {}).values()) + lag = list(obj.get("continuous_lag1_diff", {}).values()) + + avg_ks = mean(ks) + avg_jsd = mean(jsd) + avg_lag1 = mean(lag) + + print("avg_ks", avg_ks) + print("avg_jsd", avg_jsd) + print("avg_lag1_diff", avg_lag1) + + history_path = base_dir / "results" / "metrics_history.csv" + if not history_path.exists(): + history_path.write_text("timestamp,avg_ks,avg_jsd,avg_lag1_diff\n", encoding="utf-8") + with history_path.open("a", encoding="utf-8") as f: + f.write(f"{datetime.utcnow().isoformat()},{avg_ks},{avg_jsd},{avg_lag1}\n") + + +if __name__ == "__main__": + main() diff --git a/example/train.py b/example/train.py index 025f3d5..c7d55a2 100755 --- a/example/train.py +++ b/example/train.py @@ -14,6 +14,7 @@ import torch.nn.functional as F from data_utils import load_split, windowed_batches from hybrid_diffusion import ( HybridDiffusionModel, + TemporalGRUGenerator, cosine_beta_schedule, q_sample_continuous, q_sample_discrete, @@ -61,6 +62,11 @@ DEFAULTS = { "model_use_feature_graph": True, "feature_graph_scale": 0.1, "feature_graph_dropout": 0.0, + "use_temporal_stage1": True, + "temporal_hidden_dim": 256, + "temporal_num_layers": 1, + "temporal_dropout": 0.0, + "temporal_loss_weight": 1.0, "disc_mask_scale": 0.9, "shuffle_buffer": 256, "cont_loss_weighting": "none", # none | inv_std @@ -204,7 +210,19 @@ def main(): use_tanh_eps=bool(config.get("use_tanh_eps", False)), eps_scale=float(config.get("eps_scale", 1.0)), ).to(device) + temporal_model = None + if bool(config.get("use_temporal_stage1", False)): + temporal_model = TemporalGRUGenerator( + input_dim=len(cont_cols), + hidden_dim=int(config.get("temporal_hidden_dim", 256)), + num_layers=int(config.get("temporal_num_layers", 1)), + dropout=float(config.get("temporal_dropout", 0.0)), + ).to(device) opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"])) + if temporal_model is not None: + opt_temporal = torch.optim.Adam(temporal_model.parameters(), lr=float(config["lr"])) + else: + opt_temporal = None ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None betas = cosine_beta_schedule(int(config["timesteps"])).to(device) @@ -250,10 +268,20 @@ def main(): x_cont = x_cont.to(device) x_disc = x_disc.to(device) + temporal_loss = None + x_cont_resid = x_cont + trend = None + if temporal_model is not None: + trend, pred_next = temporal_model.forward_teacher(x_cont) + temporal_loss = F.mse_loss(pred_next, x_cont[:, 1:, :]) + temporal_loss = temporal_loss * float(config.get("temporal_loss_weight", 1.0)) + trend = trend.detach() + x_cont_resid = x_cont - trend + bsz = x_cont.size(0) t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device) - x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod) + x_cont_t, noise = q_sample_continuous(x_cont_resid, t, alphas_cumprod) mask_tokens = torch.tensor(vocab_sizes, device=device) x_disc_t, mask = q_sample_discrete( @@ -268,13 +296,13 @@ def main(): cont_target = str(config.get("cont_target", "eps")) if cont_target == "x0": - x0_target = x_cont + x0_target = x_cont_resid if float(config.get("cont_clamp_x0", 0.0)) > 0: x0_target = torch.clamp(x0_target, -float(config["cont_clamp_x0"]), float(config["cont_clamp_x0"])) loss_base = (eps_pred - x0_target) ** 2 elif cont_target == "v": a_bar_t = alphas_cumprod[t].view(-1, 1, 1) - v_target = torch.sqrt(a_bar_t) * noise - torch.sqrt(1.0 - a_bar_t) * x_cont + v_target = torch.sqrt(a_bar_t) * noise - torch.sqrt(1.0 - a_bar_t) * x_cont_resid loss_base = (eps_pred - v_target) ** 2 else: loss_base = (eps_pred - noise) ** 2 @@ -311,7 +339,7 @@ def main(): q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95]) q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype) # Use normalized space for stable quantiles on x0. - x_real = x_cont + x_real = x_cont_resid a_bar_t = alphas_cumprod[t].view(-1, 1, 1) if cont_target == "x0": x_gen = eps_pred @@ -336,11 +364,18 @@ def main(): else: quantile_loss = torch.mean(torch.abs(q_diff)) loss = loss + q_weight * quantile_loss + opt.zero_grad() loss.backward() if float(config.get("grad_clip", 0.0)) > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"])) opt.step() + if opt_temporal is not None: + opt_temporal.zero_grad() + temporal_loss.backward() + if float(config.get("grad_clip", 0.0)) > 0: + torch.nn.utils.clip_grad_norm_(temporal_model.parameters(), float(config["grad_clip"])) + opt_temporal.step() if ema is not None: ema.update(model) @@ -375,11 +410,15 @@ def main(): } if ema is not None: ckpt["ema"] = ema.state_dict() + if temporal_model is not None: + ckpt["temporal"] = temporal_model.state_dict() torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt")) torch.save(model.state_dict(), os.path.join(out_dir, "model.pt")) if ema is not None: torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt")) + if temporal_model is not None: + torch.save(temporal_model.state_dict(), os.path.join(out_dir, "temporal.pt")) if __name__ == "__main__": diff --git a/report.md b/report.md new file mode 100644 index 0000000..77f176f --- /dev/null +++ b/report.md @@ -0,0 +1,365 @@ +# Hybrid Diffusion for ICS Traffic (HAI 21.03) — Project Report +# 工业控制系统流量混合扩散生成(HAI 21.03)— 项目报告 + +## 1. Project Goal / 项目目标 +Build a **hybrid diffusion-based generator** for industrial control system (ICS) traffic features, targeting **mixed continuous + discrete** feature sequences. The output is **feature-level sequences**, not raw packets. The generator should preserve: +- **Distributional fidelity** (continuous value ranges and discrete frequencies) +- **Temporal consistency** (time correlation and sequence structure) +- **Protocol/field consistency** (for discrete fields) + +构建一个用于工业控制系统(ICS)流量特征的**混合扩散生成模型**,面向**连续+离散混合特征序列**。输出为**特征级序列**而非原始报文。生成结果需要同时保持: +- **分布一致性**(连续值范围与离散取值频率) +- **时序一致性**(时间相关性与序列结构) +- **字段/协议一致性**(离散字段的逻辑一致) + +This project is aligned with the STOUTER idea of **structure-aware diffusion** for spatiotemporal data, but applied to **ICS feature sequences** rather than cellular traffic. + +本项目呼应 STOUTER 的**结构先验+扩散**思想,但应用于**ICS 特征序列**而非蜂窝流量。 + +--- + +## 2. Data and Scope / 数据与范围 +**Dataset used in the current implementation:** HAI 21.03 (CSV feature traces) + +**当前实现使用的数据集:** HAI 21.03(CSV 特征序列) + +**Data location (default in config):** +- `dataset/hai/hai-21.03/train*.csv.gz` + +**数据位置(config 默认):** +- `dataset/hai/hai-21.03/train*.csv.gz` + +**Feature split (fixed schema):** +- Defined in `example/feature_split.json` +- **Continuous features:** sensor/process values +- **Discrete features:** binary/low-cardinality status/flag fields +- `time` column is excluded from modeling + +**特征拆分(固定 schema):** +- `example/feature_split.json` +- **连续特征:** 传感器/过程值 +- **离散特征:** 二值/低基数状态字段 +- `time` 列不参与训练 + +--- + +## 3. End-to-End Pipeline / 端到端流程 + +**One command pipeline:** +``` +python example/run_all.py --device cuda +``` + +**一键流程:** +``` +python example/run_all.py --device cuda +``` + +### Pipeline stages / 流程阶段 +1) **Prepare data** (`example/prepare_data.py`) +2) **Train model** (`example/train.py`) +3) **Generate samples** (`example/export_samples.py`) +4) **Evaluate** (`example/evaluate_generated.py`) +5) **Summarize metrics** (`example/summary_metrics.py`) + +1) **数据准备**(统计量与词表) +2) **训练模型** +3) **生成样本并导出** +4) **评估指标** +5) **汇总指标** + +--- + +## 4. Technical Architecture / 技术架构 + +### 4.1 Hybrid Diffusion Model (Core) / 混合扩散模型(核心) +Defined in `example/hybrid_diffusion.py`. + +**Key components:** +- **Continuous branch**: Gaussian diffusion (DDPM style) +- **Discrete branch**: Mask diffusion for categorical tokens +- **Shared backbone**: GRU + residual MLP + LayerNorm +- **Embedding inputs**: + - continuous projection + - discrete embeddings per column + - time embedding (sinusoidal) + - positional embedding (sequence index) + - optional condition embedding (`file_id`) + +**Outputs:** +- Continuous head: predicts target (`eps`, `x0`, or `v`) +- Discrete heads: predict logits for each discrete column + +**核心组成:** +- **连续分支:** 高斯扩散(DDPM) +- **离散分支:** Mask 扩散 +- **共享主干:** GRU + 残差 MLP + LayerNorm +- **输入嵌入:** + - 连续投影 + - 离散字段嵌入 + - 时间嵌入(正弦) + - 位置嵌入(序列索引) + - 条件嵌入(可选,`file_id`) + +**输出:** +- 连续 head:预测 `eps/x0/v` +- 离散 head:各字段 logits + +--- + +### 4.2 Feature Graph Mixer (Structure Prior) / 特征图混合器(结构先验) +Implemented in `example/hybrid_diffusion.py` as `FeatureGraphMixer`. + +Purpose: inject **learnable feature-dependency prior** without dataset-specific hardcoding. + +**Mechanism:** +- Learns a dense feature relation matrix `A` +- Applies: `x + x @ A` +- Symmetric stabilizing constraint: `(A + A^T)/2` +- Controlled by scale and dropout + +**Config:** +``` +"model_use_feature_graph": true, +"feature_graph_scale": 0.1, +"feature_graph_dropout": 0.0 +``` + +**目的:**在不写死特定数据集关系的情况下,引入**可学习特征依赖先验**。 + +**机制:** +- 学习稠密关系矩阵 `A` +- 特征混合:`x + x @ A` +- 对称化稳定:`(A + A^T)/2` +- 通过 scale/dropout 控制强度 + +--- + +### 4.3 Two-Stage Temporal Backbone / 两阶段时序骨干 +Stage-1 uses a **GRU temporal generator** to model sequence trend in normalized space. Stage-2 diffusion then models the **residual** (x − trend). This decouples temporal consistency from distribution alignment. + +第一阶段使用 **GRU 时序生成器**在归一化空间建模序列趋势;第二阶段扩散模型学习**残差**(x − trend),实现时序一致性与分布对齐的解耦。 + +--- + +## 5. Diffusion Formulations / 扩散建模形式 + +### 5.1 Continuous Diffusion / 连续扩散 +Forward process: +``` +x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps +``` + +Targets supported: +- **eps prediction** (standard DDPM) +- **x0 prediction** (direct reconstruction) +- **v prediction** (v = sqrt(a_bar)*eps − sqrt(1-a_bar)*x0) + +Current config default: +``` +"cont_target": "v" +``` + +Sampling uses the target to reconstruct `eps` and apply standard DDPM reverse update. + +**前向扩散:**如上公式。 + +**支持的目标:** +- `eps`(噪声预测) +- `x0`(原样本预测) +- `v`(v‑prediction) + +**当前默认:**`cont_target = v` + +**采样:**根据目标反解 `eps` 再执行标准 DDPM 反向步骤。 + +--- + +### 5.2 Discrete Diffusion (Mask) / 离散扩散(Mask) +Forward process: replace tokens with `[MASK]` using cosine schedule: +``` +p(t) = 0.5 * (1 - cos(pi * t / T)) +``` +Optional scale: `disc_mask_scale` + +Reverse process: cross-entropy on masked positions only. + +**前向:**按 cosine schedule 进行 Mask。 +**反向:**仅在 mask 位置计算交叉熵。 + +--- + +## 6. Loss Design (Current) / 当前损失设计 +Total loss: +``` +L = λ * L_cont + (1 − λ) * L_disc + w_q * L_quantile +``` + +### 6.1 Continuous Loss / 连续损失 +Depending on `cont_target`: +- eps target: MSE(eps_pred, eps) +- x0 target: MSE(x0_pred, x0) +- v target: MSE(v_pred, v_target) + +Optional inverse-variance weighting: +``` +cont_loss_weighting = "inv_std" +``` + +### 6.2 Discrete Loss / 离散损失 +Cross-entropy on masked positions only. + +### 6.3 Quantile Loss (Distribution Alignment) / 分位数损失(分布对齐) +Added to improve KS (distribution shape alignment): +- Compute quantiles on generated vs real x0 +- Loss = Huber or L1 difference on quantiles + +Stabilization: +``` +quantile_loss_warmup_steps +quantile_loss_clip +quantile_loss_huber_delta +``` + +--- + +## 7. Training Strategy / 训练策略 +Defined in `example/train.py`. + +**Key techniques:** +- EMA of model weights +- Gradient clipping +- Shuffle buffer to reduce batch bias +- Optional feature graph prior +- Quantile loss warmup for stability +- Optional stage-1 temporal GRU (trend) + residual diffusion + +**Config highlights (example/config.json):** +``` +timesteps: 600 +batch_size: 128 +seq_len: 128 +epochs: 10 +max_batches: 4000 +lambda: 0.7 +cont_target: "v" +quantile_loss_weight: 0.1 +model_use_feature_graph: true +use_temporal_stage1: true +``` + +--- + +## 8. Sampling & Export / 采样与导出 +Defined in: +- `example/sample.py` +- `example/export_samples.py` + +**Export steps:** +- Reverse diffusion with conditional sampling +- Reverse normalize continuous values +- Clamp to observed min/max +- Restore discrete tokens from vocab +- Write to CSV + +--- + +## 9. Evaluation Metrics / 评估指标 +Implemented in `example/evaluate_generated.py`. + +### Continuous Metrics / 连续指标 +- **KS statistic** (distribution similarity per feature) +- **Quantile errors** (q05/q25/q50/q75/q95) +- **Lag‑1 correlation diff** (temporal structure) + +### Discrete Metrics / 离散指标 +- **JSD** over token frequency distribution +- **Invalid token counts** + +### Summary Metrics / 汇总指标 +Auto-logged in: +- `example/results/metrics_history.csv` +- via `example/summary_metrics.py` + +--- + +## 10. Automation / 自动化 + +### One‑click pipeline / 一键流程 +``` +python example/run_all.py --device cuda +``` + +### Metrics logging / 指标记录 +Each run appends: +``` +timestamp,avg_ks,avg_jsd,avg_lag1_diff +``` + +--- + +## 11. Key Engineering Decisions / 关键工程决策 + +### 11.1 Mixed-Type Diffusion / 混合类型扩散 +Continuous + discrete handled separately to respect data types. + +### 11.2 Structure Prior / 结构先验 +Learnable feature graph added to encode implicit dependencies. + +### 11.3 v‑prediction +Chosen to stabilize training and improve convergence in diffusion. + +### 11.4 Distribution Alignment / 分布对齐 +Quantile loss introduced to directly reduce KS. + +--- + +## 12. Known Issues / Current Limitations / 已知问题与当前局限 +- **KS remains high** in many experiments, meaning continuous distributions are still misaligned. +- **Lag‑1 may degrade** when quantile loss is too strong. +- **Loss spikes** observed when quantile loss is unstable (mitigated with warmup + clip + Huber). + +**当前问题:** +- KS 高,说明连续分布仍未对齐 +- 分位数损失过强时会损害时序相关性 +- 分位数损失不稳定时会出现 loss 爆炸(已引入 warmup/clip/Huber) + +--- + +## 13. Suggested Next Steps (Research Roadmap) / 下一步建议(研究路线) +1) **SNR-weighted loss** (improve stability across timesteps) +2) **Two-stage training** (distribution first, temporal consistency second) +3) **Upgrade discrete diffusion** (D3PM-style transitions) +4) **Structured conditioning** (state/phase conditioning) +5) **Graph-based priors** (explicit feature/plant dependency graphs) + +--- + +## 14. Code Map (Key Files) / 代码索引(关键文件) + +**Core model** +- `example/hybrid_diffusion.py` + +**Training** +- `example/train.py` + +**Sampling & export** +- `example/sample.py` +- `example/export_samples.py` + +**Pipeline** +- `example/run_all.py` + +**Evaluation** +- `example/evaluate_generated.py` +- `example/summary_metrics.py` + +**Configs** +- `example/config.json` + +--- + +## 15. Summary / 总结 +This project implements a **hybrid diffusion model for ICS traffic features**, combining continuous Gaussian diffusion with discrete mask diffusion, enhanced with a **learnable feature-graph prior**. The system includes a full pipeline for preparation, training, sampling, exporting, and evaluation. Key research challenges remain in **distribution alignment (KS)** and **joint optimization of distribution fidelity vs temporal consistency**, motivating future improvements such as SNR-weighted loss, staged training, and stronger structural priors. + +本项目实现了用于 ICS 流量特征的**混合扩散模型**,将连续高斯扩散与离散 Mask 扩散结合,并引入**可学习特征图先验**。系统包含完整的数据准备、训练、采样、导出与评估流程。当前研究挑战集中在**连续分布对齐(KS)**与**分布/时序一致性之间的权衡**,后续可通过 SNR‑weighted loss、分阶段训练与更强结构先验继续改进。