# mask-ddpm 项目说明书(完整详细版) > 本文档是“说明书级别”的完整描述,面向首次接触项目的同学。 > 目标是让**不了解扩散/时序建模的人**也能理解:项目是什么、怎么跑、每个文件干什么、每一步在训练什么、为什么这么设计。 > > 适用范围:当前仓库代码(以 `example/config.json` 为主配置)。 --- ## 目录 1. 项目目标与研究问题 2. 数据与特征结构 3. 预处理与统计文件 4. 模型总体架构 5. 训练流程(逐步骤) 6. 采样与导出流程 7. 评估体系与指标 8. 诊断工具与常用脚本 9. Type‑aware(按类型分治)设计 10. 一键运行与常见命令 11. 输出文件说明 12. 当前配置与关键超参 13. 常见问题与慢的原因 14. 已知限制与后续方向 15. 文件树(精简版) 16. 文件职责(逐文件说明) --- ## 1. 项目目标与研究问题 本项目目标:生成工业控制系统(ICS)多变量时序数据,满足以下三点: - **分布一致性**:每个变量的统计分布接近真实(用 KS 衡量) - **时序一致性**:序列结构合理,lag‑1 相关性、趋势符合真实 - **离散合法性**:离散变量(状态/模式)必须是合法 token 且分布合理(JSD) 核心难点: - 时序结构和分布对齐经常相互冲突 - 真实数据包含“程序驱动/事件驱动”的变量,难以用纯 DDPM 学好 --- ## 2. 数据与特征结构 **数据来源**:HAI `train*.csv.gz`(多文件) **特征拆分**(见 `example/feature_split.json`): - `continuous`:连续变量(传感器/执行器) - `discrete`:离散变量(状态/模式) - `time_column`:时间列(不参与训练) --- ## 3. 预处理与统计文件 脚本:`example/prepare_data.py` ### 3.1 连续变量 - 计算 mean/std - 若开启 `use_quantile_transform`:计算分位数表(CDF) - 输出:`example/results/cont_stats.json` ### 3.2 离散变量 - 统计 vocab - 输出:`example/results/disc_vocab.json` ### 3.3 数据工具 `example/data_utils.py` 提供: - 标准化/反标准化 - 分位数变换/逆变换 - 可选后校准(quantile calibration) --- ## 4. 模型总体架构 本项目采用 **两阶段 + 混合扩散** 架构: ### 4.1 Stage‑1 Temporal GRU - 目的:学习序列趋势、时序结构 - 输入:连续变量序列 - 输出:trend(趋势序列) ### 4.2 Stage‑2 Hybrid Diffusion - 目的:学习残差分布(把时序和分布解耦) - 连续变量:Gaussian DDPM - 离散变量:mask diffusion 分类 head ### 4.3 Backbone 选择 - 当前配置:`backbone_type = transformer` - 可选:GRU(更省显存更稳定) --- ## 5. 训练流程(逐步骤) 脚本:`example/train.py` ### Step 1:Temporal 训练 - 输入:连续序列 - GRU teacher‑forcing 预测下一步 - Loss:MSE - 输出:`temporal.pt` ### Step 2:Diffusion 训练 - 计算残差:`x_resid = x_cont - trend` - 采样时间步 t - 连续:加噪 - 离散:mask token - 模型预测 eps / logits ### Loss 设计 - Continuous:MSE(eps 或 x0) - Discrete:Cross Entropy(mask 部分) - 总损失:`loss = λ * loss_cont + (1-λ) * loss_disc` - 可选加权: - inverse‑std - SNR‑weighted - quantile loss - residual stat loss --- ## 6. 采样与导出流程 脚本:`example/export_samples.py` 流程: 1) 初始化噪声(连续) 2) 初始化 mask(离散) 3) 反扩散 t=T..0 4) 加回 trend 5) 反变换(quantile/标准化) 6) 合成 CSV 输出:`example/results/generated.csv` --- ## 7. 评估体系与指标 脚本:`example/evaluate_generated.py` ### 连续指标 - **KS(tie‑aware)** - quantile diff - lag‑1 correlation ### 离散指标 - JSD - invalid token 比例 ### Reference 读取 - 支持 `train*.csv.gz` glob - 自动汇总所有文件 --- ## 8. 诊断工具与常用脚本 - `diagnose_ks.py`:CDF 可视化 - `ranked_ks.py`:KS 贡献排序 - `filtered_metrics.py`:过滤异常特征后的 KS - `program_stats.py`:Type1 统计 - `controller_stats.py`:Type2 统计 - `actuator_stats.py`:Type3 统计 - `pv_stats.py`:Type4 统计 - `aux_stats.py`:Type6 统计 --- ## 9. Type‑aware 设计(按类型分治) 在真实 ICS 中,部分变量很难用 DDPM 学到,所以做类型划分: - **Type1**:setpoint/demand(调度驱动) - **Type2**:controller outputs - **Type3**:actuator positions - **Type4**:PV sensors - **Type5**:derived tags - **Type6**:aux/coupling 脚本:`example/postprocess_types.py` 当前实现是 **KS‑only baseline**: - Type1/2/3/5/6 → 经验重采样 - Type4 → 仍用 diffusion 用途: - 快速诊断“KS 最优可达上界” - 不保证联合分布真实性 输出:`example/results/generated_post.csv` --- ## 10. 一键运行与常见命令 ### 全流程(推荐) ```bash python example/run_all.py --device cuda --config example/config.json ``` ### 只评估不训练 ```bash python example/run_all.py --skip-prepare --skip-train --skip-export ``` ### 只训练不评估 ```bash python example/run_all.py --skip-eval --skip-postprocess --skip-post-eval --skip-diagnostics ``` --- ## 11. 输出文件说明 - `generated.csv`:原始 diffusion 输出 - `generated_post.csv`:KS‑only 后处理输出 - `eval.json`:原始评估 - `eval_post.json`:后处理评估 - `cont_stats.json` / `disc_vocab.json`:统计文件 - `*_stats.json`:Type 统计报告 --- ## 12. 当前配置(关键超参) 来自 `example/config.json`: - backbone_type: **transformer** - timesteps: 600 - seq_len: 96 - batch_size: 16 - cont_target: x0 - cont_loss_weighting: inv_std - snr_weighted_loss: true - quantile_loss_weight: 0.2 - use_quantile_transform: true - cont_post_calibrate: true - use_temporal_stage1: true --- ## 13. 为什么运行慢 1) 两阶段训练(temporal + diffusion) 2) 评估要读全量 train*.csv.gz 3) run_all 默认跑所有诊断脚本 4) timesteps / seq_len 大 --- ## 14. 已知限制与后续方向 限制: - Type1/2/3 仍主导 KS - KS‑only baseline 会破坏联合分布 - 时序和分布存在 trade‑off 方向: - 为 Type1/2/3 建条件模型 - Type4 增加 regime conditioning - 联合指标(cross‑feature correlation) --- ## 15. 文件树(精简版) ``` mask-ddpm/ report.md docs/ README.md architecture.md evaluation.md decisions.md experiments.md ideas.md example/ config.json config_no_temporal.json config_temporal_strong.json feature_split.json data_utils.py prepare_data.py hybrid_diffusion.py train.py sample.py export_samples.py evaluate_generated.py run_all.py run_compare.py diagnose_ks.py filtered_metrics.py ranked_ks.py program_stats.py controller_stats.py actuator_stats.py pv_stats.py aux_stats.py postprocess_types.py results/ ``` --- ## 16. 文件职责(逐文件说明) - `prepare_data.py`:统计连续/离散特征 - `data_utils.py`:预处理与变换函数 - `hybrid_diffusion.py`:模型主体(Temporal + Diffusion) - `train.py`:两阶段训练 - `export_samples.py`:采样导出 - `evaluate_generated.py`:评估指标 - `run_all.py`:一键流程 - `postprocess_types.py`:Type‑aware KS‑only baseline - `diagnose_ks.py`:CDF 诊断 - `ranked_ks.py`:KS 排序 - `filtered_metrics.py`:过滤 KS --- # 结束 如果你需要更“论文式”的版本(加入公式、伪代码、实验表格),可以继续追加。 --- # 附录 A:公式汇总(论文可用版) > 说明:本附录包含 **代码中已有** 的核心公式,以及 **合理的扩展公式**(可作为方法增强/未来工作)。公式写法尽量简洁,但强调“可解释 + 可复现”。 ## A.1 现有实现可对齐的核心公式 ### (1) 连续扩散(残差 DDPM) \[ x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\quad \epsilon\sim\mathcal{N}(0,I) \] - 解释:对残差进行标准 DDPM 加噪 ### (2) 连续分支损失(x0 或 eps 预测) \[ \mathcal{L}_{cont} = \begin{cases} \| \hat{\epsilon}_\theta - \epsilon \|^2 & \text{if target=eps}\\ \| \hat{x}_0 - x_0 \|^2 & \text{if target=x0} \end{cases} \] - 解释:当前配置使用 `cont_target=x0` 或 `eps` ### (3) SNR 加权(当前代码可选) \[ \mathcal{L}_{snr} = \frac{\text{SNR}_t}{\text{SNR}_t+\gamma}\,\mathcal{L}_{cont} \] - 解释:高噪声阶段减小权重 ### (4) 离散 Mask‑Diffusion 交叉熵 \[ \mathcal{L}_{disc} = \frac{1}{|\mathcal{M}|}\sum_{(i,t)\in\mathcal{M}} \mathrm{CE}(\hat{p}_{i,t}, y_{i,t}) \] - 解释:只对被 mask 的 token 计算 ### (5) 总损失 \[ \mathcal{L} = \lambda \mathcal{L}_{cont} + (1-\lambda)\mathcal{L}_{disc} \] - 解释:控制分布 vs 离散的权衡 ### (6) 分位数分布对齐(残差空间) \[ \mathcal{L}_{Q} = \frac{1}{K}\sum_{k=1}^{K}\|Q_k(x_{real}) - Q_k(x_{gen})\|_1 \] - 解释:对齐分位数,改善 KS --- ## A.2 合理扩展公式(可作为增强项 / 未来工作) ### (7) 时序一致性正则(Lag‑1 约束) \[ \mathcal{L}_{lag1} = \|\rho_1(x_{gen}) - \rho_1(x_{real})\|_1 \] - 解释:抑制时序退化(lag‑1 diff) ### (8) 频谱一致性(Temporal PSD) \[ \mathcal{L}_{spec} = \|\log S(\omega; x_{gen}) - \log S(\omega; x_{real})\|_1 \] - 解释:捕捉周期/扫描频率结构 ### (9) 多尺度 Wasserstein(分布 + 时序混合) \[ \mathcal{L}_{MSW} = \sum_{s\in\mathcal{S}} W_1\big(\phi_s(x_{gen}),\phi_s(x_{real})\big) \] - 解释:多尺度对齐,兼顾分布与结构 ### (10) 条件一致性(Type‑aware) \[ \mathcal{L}_{cond} = \mathbb{E}\big[\|f_{ctrl}(x_{gen}) - f_{ctrl}(x_{real})\|_2^2\big] \] - 解释:约束控制器/执行器在条件下合理 --- ## A.3 评估指标(论文描述版) ### (11) 分布对齐(KS) \[ \text{KS}_i = \sup_x |F^{(i)}_{gen}(x)-F^{(i)}_{real}(x)| \] \[ \text{avg\_KS} = \frac{1}{d}\sum_{i=1}^{d}\text{KS}_i \] ### (12) 离散一致性(JSD) \[ \text{JSD}(P,Q)=\tfrac12 KL(P\|M)+\tfrac12 KL(Q\|M) \] ### (13) 时序偏差(Lag‑1 Diff) \[ \Delta_{lag1} = \frac{1}{d}\sum_i |\rho_1(x^{(i)}_{gen})-\rho_1(x^{(i)}_{real})| \]