10 KiB
mask-ddpm 项目说明书(完整详细版)
本文档是“说明书级别”的完整描述,面向首次接触项目的同学。 目标是让不了解扩散/时序建模的人也能理解:项目是什么、怎么跑、每个文件干什么、每一步在训练什么、为什么这么设计。
适用范围:当前仓库代码(以
example/config.json为主配置)。
目录
- 项目目标与研究问题
- 数据与特征结构
- 预处理与统计文件
- 模型总体架构
- 训练流程(逐步骤)
- 采样与导出流程
- 评估体系与指标
- 诊断工具与常用脚本
- Type‑aware(按类型分治)设计
- 一键运行与常见命令
- 输出文件说明
- 当前配置与关键超参
- 常见问题与慢的原因
- 已知限制与后续方向
- 文件树(精简版)
- 文件职责(逐文件说明)
1. 项目目标与研究问题
本项目目标:生成工业控制系统(ICS)多变量时序数据,满足以下三点:
- 分布一致性:每个变量的统计分布接近真实(用 KS 衡量)
- 时序一致性:序列结构合理,lag‑1 相关性、趋势符合真实
- 离散合法性:离散变量(状态/模式)必须是合法 token 且分布合理(JSD)
核心难点:
- 时序结构和分布对齐经常相互冲突
- 真实数据包含“程序驱动/事件驱动”的变量,难以用纯 DDPM 学好
2. 数据与特征结构
数据来源:HAI train*.csv.gz(多文件)
特征拆分(见 example/feature_split.json):
continuous:连续变量(传感器/执行器)discrete:离散变量(状态/模式)time_column:时间列(不参与训练)
3. 预处理与统计文件
脚本:example/prepare_data.py
3.1 连续变量
- 计算 mean/std
- 若开启
use_quantile_transform:计算分位数表(CDF) - 输出:
example/results/cont_stats.json
3.2 离散变量
- 统计 vocab
- 输出:
example/results/disc_vocab.json
3.3 数据工具
example/data_utils.py 提供:
- 标准化/反标准化
- 分位数变换/逆变换
- 可选后校准(quantile calibration)
4. 模型总体架构
本项目采用 两阶段 + 混合扩散 架构:
4.1 Stage‑1 Temporal GRU
- 目的:学习序列趋势、时序结构
- 输入:连续变量序列
- 输出:trend(趋势序列)
4.2 Stage‑2 Hybrid Diffusion
- 目的:学习残差分布(把时序和分布解耦)
- 连续变量:Gaussian DDPM
- 离散变量:mask diffusion 分类 head
4.3 Backbone 选择
- 当前配置:
backbone_type = transformer - 可选:GRU(更省显存更稳定)
5. 训练流程(逐步骤)
脚本:example/train.py
Step 1:Temporal 训练
- 输入:连续序列
- GRU teacher‑forcing 预测下一步
- Loss:MSE
- 输出:
temporal.pt
Step 2:Diffusion 训练
- 计算残差:
x_resid = x_cont - trend - 采样时间步 t
- 连续:加噪
- 离散:mask token
- 模型预测 eps / logits
Loss 设计
- Continuous:MSE(eps 或 x0)
- Discrete:Cross Entropy(mask 部分)
- 总损失:
loss = λ * loss_cont + (1-λ) * loss_disc - 可选加权:
- inverse‑std
- SNR‑weighted
- quantile loss
- residual stat loss
6. 采样与导出流程
脚本:example/export_samples.py
流程:
- 初始化噪声(连续)
- 初始化 mask(离散)
- 反扩散 t=T..0
- 加回 trend
- 反变换(quantile/标准化)
- 合成 CSV
输出:example/results/generated.csv
7. 评估体系与指标
脚本:example/evaluate_generated.py
连续指标
- KS(tie‑aware)
- quantile diff
- lag‑1 correlation
离散指标
- JSD
- invalid token 比例
Reference 读取
- 支持
train*.csv.gzglob - 自动汇总所有文件
8. 诊断工具与常用脚本
diagnose_ks.py:CDF 可视化ranked_ks.py:KS 贡献排序filtered_metrics.py:过滤异常特征后的 KSprogram_stats.py:Type1 统计controller_stats.py:Type2 统计actuator_stats.py:Type3 统计pv_stats.py:Type4 统计aux_stats.py:Type6 统计
9. Type‑aware 设计(按类型分治)
在真实 ICS 中,部分变量很难用 DDPM 学到,所以做类型划分:
- Type1:setpoint/demand(调度驱动)
- Type2:controller outputs
- Type3:actuator positions
- Type4:PV sensors
- Type5:derived tags
- Type6:aux/coupling
脚本:example/postprocess_types.py
当前实现是 KS‑only baseline:
- Type1/2/3/5/6 → 经验重采样
- Type4 → 仍用 diffusion
用途:
- 快速诊断“KS 最优可达上界”
- 不保证联合分布真实性
输出:example/results/generated_post.csv
10. 一键运行与常见命令
全流程(推荐)
python example/run_all.py --device cuda --config example/config.json
只评估不训练
python example/run_all.py --skip-prepare --skip-train --skip-export
只训练不评估
python example/run_all.py --skip-eval --skip-postprocess --skip-post-eval --skip-diagnostics
11. 输出文件说明
generated.csv:原始 diffusion 输出generated_post.csv:KS‑only 后处理输出eval.json:原始评估eval_post.json:后处理评估cont_stats.json/disc_vocab.json:统计文件*_stats.json:Type 统计报告
12. 当前配置(关键超参)
来自 example/config.json:
- backbone_type: transformer
- timesteps: 600
- seq_len: 96
- batch_size: 16
- cont_target: x0
- cont_loss_weighting: inv_std
- snr_weighted_loss: true
- quantile_loss_weight: 0.2
- use_quantile_transform: true
- cont_post_calibrate: true
- use_temporal_stage1: true
13. 为什么运行慢
- 两阶段训练(temporal + diffusion)
- 评估要读全量 train*.csv.gz
- run_all 默认跑所有诊断脚本
- timesteps / seq_len 大
14. 已知限制与后续方向
限制:
- Type1/2/3 仍主导 KS
- KS‑only baseline 会破坏联合分布
- 时序和分布存在 trade‑off
方向:
- 为 Type1/2/3 建条件模型
- Type4 增加 regime conditioning
- 联合指标(cross‑feature correlation)
15. 文件树(精简版)
mask-ddpm/
report.md
docs/
README.md
architecture.md
evaluation.md
decisions.md
experiments.md
ideas.md
example/
config.json
config_no_temporal.json
config_temporal_strong.json
feature_split.json
data_utils.py
prepare_data.py
hybrid_diffusion.py
train.py
sample.py
export_samples.py
evaluate_generated.py
run_all.py
run_compare.py
diagnose_ks.py
filtered_metrics.py
ranked_ks.py
program_stats.py
controller_stats.py
actuator_stats.py
pv_stats.py
aux_stats.py
postprocess_types.py
results/
16. 文件职责(逐文件说明)
prepare_data.py:统计连续/离散特征data_utils.py:预处理与变换函数hybrid_diffusion.py:模型主体(Temporal + Diffusion)train.py:两阶段训练export_samples.py:采样导出evaluate_generated.py:评估指标run_all.py:一键流程postprocess_types.py:Type‑aware KS‑only baselinediagnose_ks.py:CDF 诊断ranked_ks.py:KS 排序filtered_metrics.py:过滤 KS
结束
如果你需要更“论文式”的版本(加入公式、伪代码、实验表格),可以继续追加。
附录 A:公式汇总(论文可用版)
说明:本附录包含 代码中已有 的核心公式,以及 合理的扩展公式(可作为方法增强/未来工作)。公式写法尽量简洁,但强调“可解释 + 可复现”。
A.1 现有实现可对齐的核心公式
(1) 连续扩散(残差 DDPM)
[ x_t = \sqrt{\bar{\alpha}_t},x_0 + \sqrt{1-\bar{\alpha}_t},\epsilon,\quad \epsilon\sim\mathcal{N}(0,I) ]
- 解释:对残差进行标准 DDPM 加噪
(2) 连续分支损失(x0 或 eps 预测)
[ \mathcal{L}{cont} = \begin{cases} | \hat{\epsilon}\theta - \epsilon |^2 & \text{if target=eps}\ | \hat{x}_0 - x_0 |^2 & \text{if target=x0} \end{cases} ]
- 解释:当前配置使用
cont_target=x0或eps
(3) SNR 加权(当前代码可选)
[ \mathcal{L}_{snr} = \frac{\text{SNR}_t}{\text{SNR}t+\gamma},\mathcal{L}{cont} ]
- 解释:高噪声阶段减小权重
(4) 离散 Mask‑Diffusion 交叉熵
[ \mathcal{L}{disc} = \frac{1}{|\mathcal{M}|}\sum{(i,t)\in\mathcal{M}} \mathrm{CE}(\hat{p}{i,t}, y{i,t}) ]
- 解释:只对被 mask 的 token 计算
(5) 总损失
[ \mathcal{L} = \lambda \mathcal{L}{cont} + (1-\lambda)\mathcal{L}{disc} ]
- 解释:控制分布 vs 离散的权衡
(6) 分位数分布对齐(残差空间)
[ \mathcal{L}{Q} = \frac{1}{K}\sum{k=1}^{K}|Q_k(x_{real}) - Q_k(x_{gen})|_1 ]
- 解释:对齐分位数,改善 KS
A.2 合理扩展公式(可作为增强项 / 未来工作)
(7) 时序一致性正则(Lag‑1 约束)
[ \mathcal{L}{lag1} = |\rho_1(x{gen}) - \rho_1(x_{real})|_1 ]
- 解释:抑制时序退化(lag‑1 diff)
(8) 频谱一致性(Temporal PSD)
[ \mathcal{L}{spec} = |\log S(\omega; x{gen}) - \log S(\omega; x_{real})|_1 ]
- 解释:捕捉周期/扫描频率结构
(9) 多尺度 Wasserstein(分布 + 时序混合)
[ \mathcal{L}{MSW} = \sum{s\in\mathcal{S}} W_1\big(\phi_s(x_{gen}),\phi_s(x_{real})\big) ]
- 解释:多尺度对齐,兼顾分布与结构
(10) 条件一致性(Type‑aware)
[ \mathcal{L}{cond} = \mathbb{E}\big[|f{ctrl}(x_{gen}) - f_{ctrl}(x_{real})|_2^2\big] ]
- 解释:约束控制器/执行器在条件下合理
A.3 评估指标(论文描述版)
(11) 分布对齐(KS)
[ \text{KS}i = \sup_x |F^{(i)}{gen}(x)-F^{(i)}{real}(x)| ] [ \text{avg_KS} = \frac{1}{d}\sum{i=1}^{d}\text{KS}_i ]
(12) 离散一致性(JSD)
[ \text{JSD}(P,Q)=\tfrac12 KL(P|M)+\tfrac12 KL(Q|M) ]
(13) 时序偏差(Lag‑1 Diff)
[ \Delta_{lag1} = \frac{1}{d}\sum_i |\rho_1(x^{(i)}{gen})-\rho_1(x^{(i)}{real})| ]