425 lines
10 KiB
Markdown
425 lines
10 KiB
Markdown
# mask-ddpm 项目说明书(完整详细版)
|
||
|
||
> 本文档是“说明书级别”的完整描述,面向首次接触项目的同学。
|
||
> 目标是让**不了解扩散/时序建模的人**也能理解:项目是什么、怎么跑、每个文件干什么、每一步在训练什么、为什么这么设计。
|
||
>
|
||
> 适用范围:当前仓库代码(以 `example/config.json` 为主配置)。
|
||
|
||
---
|
||
|
||
## 目录
|
||
1. 项目目标与研究问题
|
||
2. 数据与特征结构
|
||
3. 预处理与统计文件
|
||
4. 模型总体架构
|
||
5. 训练流程(逐步骤)
|
||
6. 采样与导出流程
|
||
7. 评估体系与指标
|
||
8. 诊断工具与常用脚本
|
||
9. Type‑aware(按类型分治)设计
|
||
10. 一键运行与常见命令
|
||
11. 输出文件说明
|
||
12. 当前配置与关键超参
|
||
13. 常见问题与慢的原因
|
||
14. 已知限制与后续方向
|
||
15. 文件树(精简版)
|
||
16. 文件职责(逐文件说明)
|
||
|
||
---
|
||
|
||
## 1. 项目目标与研究问题
|
||
|
||
本项目目标:生成工业控制系统(ICS)多变量时序数据,满足以下三点:
|
||
|
||
- **分布一致性**:每个变量的统计分布接近真实(用 KS 衡量)
|
||
- **时序一致性**:序列结构合理,lag‑1 相关性、趋势符合真实
|
||
- **离散合法性**:离散变量(状态/模式)必须是合法 token 且分布合理(JSD)
|
||
|
||
核心难点:
|
||
- 时序结构和分布对齐经常相互冲突
|
||
- 真实数据包含“程序驱动/事件驱动”的变量,难以用纯 DDPM 学好
|
||
|
||
---
|
||
|
||
## 2. 数据与特征结构
|
||
|
||
**数据来源**:HAI `train*.csv.gz`(多文件)
|
||
|
||
**特征拆分**(见 `example/feature_split.json`):
|
||
- `continuous`:连续变量(传感器/执行器)
|
||
- `discrete`:离散变量(状态/模式)
|
||
- `time_column`:时间列(不参与训练)
|
||
|
||
---
|
||
|
||
## 3. 预处理与统计文件
|
||
|
||
脚本:`example/prepare_data.py`
|
||
|
||
### 3.1 连续变量
|
||
- 计算 mean/std
|
||
- 若开启 `use_quantile_transform`:计算分位数表(CDF)
|
||
- 输出:`example/results/cont_stats.json`
|
||
|
||
### 3.2 离散变量
|
||
- 统计 vocab
|
||
- 输出:`example/results/disc_vocab.json`
|
||
|
||
### 3.3 数据工具
|
||
`example/data_utils.py` 提供:
|
||
- 标准化/反标准化
|
||
- 分位数变换/逆变换
|
||
- 可选后校准(quantile calibration)
|
||
|
||
---
|
||
|
||
## 4. 模型总体架构
|
||
|
||
本项目采用 **两阶段 + 混合扩散** 架构:
|
||
|
||
### 4.1 Stage‑1 Temporal GRU
|
||
- 目的:学习序列趋势、时序结构
|
||
- 输入:连续变量序列
|
||
- 输出:trend(趋势序列)
|
||
|
||
### 4.2 Stage‑2 Hybrid Diffusion
|
||
- 目的:学习残差分布(把时序和分布解耦)
|
||
- 连续变量:Gaussian DDPM
|
||
- 离散变量:mask diffusion 分类 head
|
||
|
||
### 4.3 Backbone 选择
|
||
- 当前配置:`backbone_type = transformer`
|
||
- 可选:GRU(更省显存更稳定)
|
||
|
||
---
|
||
|
||
## 5. 训练流程(逐步骤)
|
||
|
||
脚本:`example/train.py`
|
||
|
||
### Step 1:Temporal 训练
|
||
- 输入:连续序列
|
||
- GRU teacher‑forcing 预测下一步
|
||
- Loss:MSE
|
||
- 输出:`temporal.pt`
|
||
|
||
### Step 2:Diffusion 训练
|
||
- 计算残差:`x_resid = x_cont - trend`
|
||
- 采样时间步 t
|
||
- 连续:加噪
|
||
- 离散:mask token
|
||
- 模型预测 eps / logits
|
||
|
||
### Loss 设计
|
||
- Continuous:MSE(eps 或 x0)
|
||
- Discrete:Cross Entropy(mask 部分)
|
||
- 总损失:`loss = λ * loss_cont + (1-λ) * loss_disc`
|
||
- 可选加权:
|
||
- inverse‑std
|
||
- SNR‑weighted
|
||
- quantile loss
|
||
- residual stat loss
|
||
|
||
---
|
||
|
||
## 6. 采样与导出流程
|
||
|
||
脚本:`example/export_samples.py`
|
||
|
||
流程:
|
||
1) 初始化噪声(连续)
|
||
2) 初始化 mask(离散)
|
||
3) 反扩散 t=T..0
|
||
4) 加回 trend
|
||
5) 反变换(quantile/标准化)
|
||
6) 合成 CSV
|
||
|
||
输出:`example/results/generated.csv`
|
||
|
||
---
|
||
|
||
## 7. 评估体系与指标
|
||
|
||
脚本:`example/evaluate_generated.py`
|
||
|
||
### 连续指标
|
||
- **KS(tie‑aware)**
|
||
- quantile diff
|
||
- lag‑1 correlation
|
||
|
||
### 离散指标
|
||
- JSD
|
||
- invalid token 比例
|
||
|
||
### Reference 读取
|
||
- 支持 `train*.csv.gz` glob
|
||
- 自动汇总所有文件
|
||
|
||
---
|
||
|
||
## 8. 诊断工具与常用脚本
|
||
|
||
- `diagnose_ks.py`:CDF 可视化
|
||
- `ranked_ks.py`:KS 贡献排序
|
||
- `filtered_metrics.py`:过滤异常特征后的 KS
|
||
- `program_stats.py`:Type1 统计
|
||
- `controller_stats.py`:Type2 统计
|
||
- `actuator_stats.py`:Type3 统计
|
||
- `pv_stats.py`:Type4 统计
|
||
- `aux_stats.py`:Type6 统计
|
||
|
||
---
|
||
|
||
## 9. Type‑aware 设计(按类型分治)
|
||
|
||
在真实 ICS 中,部分变量很难用 DDPM 学到,所以做类型划分:
|
||
|
||
- **Type1**:setpoint/demand(调度驱动)
|
||
- **Type2**:controller outputs
|
||
- **Type3**:actuator positions
|
||
- **Type4**:PV sensors
|
||
- **Type5**:derived tags
|
||
- **Type6**:aux/coupling
|
||
|
||
脚本:`example/postprocess_types.py`
|
||
|
||
当前实现是 **KS‑only baseline**:
|
||
- Type1/2/3/5/6 → 经验重采样
|
||
- Type4 → 仍用 diffusion
|
||
|
||
用途:
|
||
- 快速诊断“KS 最优可达上界”
|
||
- 不保证联合分布真实性
|
||
|
||
输出:`example/results/generated_post.csv`
|
||
|
||
---
|
||
|
||
## 10. 一键运行与常见命令
|
||
|
||
### 全流程(推荐)
|
||
```bash
|
||
python example/run_all.py --device cuda --config example/config.json
|
||
```
|
||
|
||
### 只评估不训练
|
||
```bash
|
||
python example/run_all.py --skip-prepare --skip-train --skip-export
|
||
```
|
||
|
||
### 只训练不评估
|
||
```bash
|
||
python example/run_all.py --skip-eval --skip-postprocess --skip-post-eval --skip-diagnostics
|
||
```
|
||
|
||
---
|
||
|
||
## 11. 输出文件说明
|
||
|
||
- `generated.csv`:原始 diffusion 输出
|
||
- `generated_post.csv`:KS‑only 后处理输出
|
||
- `eval.json`:原始评估
|
||
- `eval_post.json`:后处理评估
|
||
- `cont_stats.json` / `disc_vocab.json`:统计文件
|
||
- `*_stats.json`:Type 统计报告
|
||
|
||
---
|
||
|
||
## 12. 当前配置(关键超参)
|
||
|
||
来自 `example/config.json`:
|
||
- backbone_type: **transformer**
|
||
- timesteps: 600
|
||
- seq_len: 96
|
||
- batch_size: 16
|
||
- cont_target: x0
|
||
- cont_loss_weighting: inv_std
|
||
- snr_weighted_loss: true
|
||
- quantile_loss_weight: 0.2
|
||
- use_quantile_transform: true
|
||
- cont_post_calibrate: true
|
||
- use_temporal_stage1: true
|
||
|
||
---
|
||
|
||
## 13. 为什么运行慢
|
||
|
||
1) 两阶段训练(temporal + diffusion)
|
||
2) 评估要读全量 train*.csv.gz
|
||
3) run_all 默认跑所有诊断脚本
|
||
4) timesteps / seq_len 大
|
||
|
||
---
|
||
|
||
## 14. 已知限制与后续方向
|
||
|
||
限制:
|
||
- Type1/2/3 仍主导 KS
|
||
- KS‑only baseline 会破坏联合分布
|
||
- 时序和分布存在 trade‑off
|
||
|
||
方向:
|
||
- 为 Type1/2/3 建条件模型
|
||
- Type4 增加 regime conditioning
|
||
- 联合指标(cross‑feature correlation)
|
||
|
||
---
|
||
|
||
## 15. 文件树(精简版)
|
||
|
||
```
|
||
mask-ddpm/
|
||
report.md
|
||
docs/
|
||
README.md
|
||
architecture.md
|
||
evaluation.md
|
||
decisions.md
|
||
experiments.md
|
||
ideas.md
|
||
example/
|
||
config.json
|
||
config_no_temporal.json
|
||
config_temporal_strong.json
|
||
feature_split.json
|
||
data_utils.py
|
||
prepare_data.py
|
||
hybrid_diffusion.py
|
||
train.py
|
||
sample.py
|
||
export_samples.py
|
||
evaluate_generated.py
|
||
run_all.py
|
||
run_compare.py
|
||
diagnose_ks.py
|
||
filtered_metrics.py
|
||
ranked_ks.py
|
||
program_stats.py
|
||
controller_stats.py
|
||
actuator_stats.py
|
||
pv_stats.py
|
||
aux_stats.py
|
||
postprocess_types.py
|
||
results/
|
||
```
|
||
|
||
---
|
||
|
||
## 16. 文件职责(逐文件说明)
|
||
|
||
- `prepare_data.py`:统计连续/离散特征
|
||
- `data_utils.py`:预处理与变换函数
|
||
- `hybrid_diffusion.py`:模型主体(Temporal + Diffusion)
|
||
- `train.py`:两阶段训练
|
||
- `export_samples.py`:采样导出
|
||
- `evaluate_generated.py`:评估指标
|
||
- `run_all.py`:一键流程
|
||
- `postprocess_types.py`:Type‑aware KS‑only baseline
|
||
- `diagnose_ks.py`:CDF 诊断
|
||
- `ranked_ks.py`:KS 排序
|
||
- `filtered_metrics.py`:过滤 KS
|
||
|
||
---
|
||
|
||
# 结束
|
||
如果你需要更“论文式”的版本(加入公式、伪代码、实验表格),可以继续追加。
|
||
|
||
---
|
||
|
||
# 附录 A:公式汇总(论文可用版)
|
||
|
||
> 说明:本附录包含 **代码中已有** 的核心公式,以及 **合理的扩展公式**(可作为方法增强/未来工作)。公式写法尽量简洁,但强调“可解释 + 可复现”。
|
||
|
||
## A.1 现有实现可对齐的核心公式
|
||
|
||
### (1) 连续扩散(残差 DDPM)
|
||
\[
|
||
x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\quad \epsilon\sim\mathcal{N}(0,I)
|
||
\]
|
||
- 解释:对残差进行标准 DDPM 加噪
|
||
|
||
### (2) 连续分支损失(x0 或 eps 预测)
|
||
\[
|
||
\mathcal{L}_{cont} =
|
||
\begin{cases}
|
||
\| \hat{\epsilon}_\theta - \epsilon \|^2 & \text{if target=eps}\\
|
||
\| \hat{x}_0 - x_0 \|^2 & \text{if target=x0}
|
||
\end{cases}
|
||
\]
|
||
- 解释:当前配置使用 `cont_target=x0` 或 `eps`
|
||
|
||
### (3) SNR 加权(当前代码可选)
|
||
\[
|
||
\mathcal{L}_{snr} = \frac{\text{SNR}_t}{\text{SNR}_t+\gamma}\,\mathcal{L}_{cont}
|
||
\]
|
||
- 解释:高噪声阶段减小权重
|
||
|
||
### (4) 离散 Mask‑Diffusion 交叉熵
|
||
\[
|
||
\mathcal{L}_{disc} = \frac{1}{|\mathcal{M}|}\sum_{(i,t)\in\mathcal{M}} \mathrm{CE}(\hat{p}_{i,t}, y_{i,t})
|
||
\]
|
||
- 解释:只对被 mask 的 token 计算
|
||
|
||
### (5) 总损失
|
||
\[
|
||
\mathcal{L} = \lambda \mathcal{L}_{cont} + (1-\lambda)\mathcal{L}_{disc}
|
||
\]
|
||
- 解释:控制分布 vs 离散的权衡
|
||
|
||
### (6) 分位数分布对齐(残差空间)
|
||
\[
|
||
\mathcal{L}_{Q} = \frac{1}{K}\sum_{k=1}^{K}\|Q_k(x_{real}) - Q_k(x_{gen})\|_1
|
||
\]
|
||
- 解释:对齐分位数,改善 KS
|
||
|
||
---
|
||
|
||
## A.2 合理扩展公式(可作为增强项 / 未来工作)
|
||
|
||
### (7) 时序一致性正则(Lag‑1 约束)
|
||
\[
|
||
\mathcal{L}_{lag1} = \|\rho_1(x_{gen}) - \rho_1(x_{real})\|_1
|
||
\]
|
||
- 解释:抑制时序退化(lag‑1 diff)
|
||
|
||
### (8) 频谱一致性(Temporal PSD)
|
||
\[
|
||
\mathcal{L}_{spec} = \|\log S(\omega; x_{gen}) - \log S(\omega; x_{real})\|_1
|
||
\]
|
||
- 解释:捕捉周期/扫描频率结构
|
||
|
||
### (9) 多尺度 Wasserstein(分布 + 时序混合)
|
||
\[
|
||
\mathcal{L}_{MSW} = \sum_{s\in\mathcal{S}} W_1\big(\phi_s(x_{gen}),\phi_s(x_{real})\big)
|
||
\]
|
||
- 解释:多尺度对齐,兼顾分布与结构
|
||
|
||
### (10) 条件一致性(Type‑aware)
|
||
\[
|
||
\mathcal{L}_{cond} = \mathbb{E}\big[\|f_{ctrl}(x_{gen}) - f_{ctrl}(x_{real})\|_2^2\big]
|
||
\]
|
||
- 解释:约束控制器/执行器在条件下合理
|
||
|
||
---
|
||
|
||
## A.3 评估指标(论文描述版)
|
||
|
||
### (11) 分布对齐(KS)
|
||
\[
|
||
\text{KS}_i = \sup_x |F^{(i)}_{gen}(x)-F^{(i)}_{real}(x)|
|
||
\]
|
||
\[
|
||
\text{avg\_KS} = \frac{1}{d}\sum_{i=1}^{d}\text{KS}_i
|
||
\]
|
||
|
||
### (12) 离散一致性(JSD)
|
||
\[
|
||
\text{JSD}(P,Q)=\tfrac12 KL(P\|M)+\tfrac12 KL(Q\|M)
|
||
\]
|
||
|
||
### (13) 时序偏差(Lag‑1 Diff)
|
||
\[
|
||
\Delta_{lag1} = \frac{1}{d}\sum_i |\rho_1(x^{(i)}_{gen})-\rho_1(x^{(i)}_{real})|
|
||
\]
|
||
|