326 lines
7.4 KiB
Markdown
326 lines
7.4 KiB
Markdown
# mask-ddpm 项目说明书(完整详细版)
|
||
|
||
> 本文档是“说明书级别”的完整描述,面向首次接触项目的同学。
|
||
> 目标是让**不了解扩散/时序建模的人**也能理解:项目是什么、怎么跑、每个文件干什么、每一步在训练什么、为什么这么设计。
|
||
>
|
||
> 适用范围:当前仓库代码(以 `example/config.json` 为主配置)。
|
||
|
||
---
|
||
|
||
## 目录
|
||
1. 项目目标与研究问题
|
||
2. 数据与特征结构
|
||
3. 预处理与统计文件
|
||
4. 模型总体架构
|
||
5. 训练流程(逐步骤)
|
||
6. 采样与导出流程
|
||
7. 评估体系与指标
|
||
8. 诊断工具与常用脚本
|
||
9. Type‑aware(按类型分治)设计
|
||
10. 一键运行与常见命令
|
||
11. 输出文件说明
|
||
12. 当前配置与关键超参
|
||
13. 常见问题与慢的原因
|
||
14. 已知限制与后续方向
|
||
15. 文件树(精简版)
|
||
16. 文件职责(逐文件说明)
|
||
|
||
---
|
||
|
||
## 1. 项目目标与研究问题
|
||
|
||
本项目目标:生成工业控制系统(ICS)多变量时序数据,满足以下三点:
|
||
|
||
- **分布一致性**:每个变量的统计分布接近真实(用 KS 衡量)
|
||
- **时序一致性**:序列结构合理,lag‑1 相关性、趋势符合真实
|
||
- **离散合法性**:离散变量(状态/模式)必须是合法 token 且分布合理(JSD)
|
||
|
||
核心难点:
|
||
- 时序结构和分布对齐经常相互冲突
|
||
- 真实数据包含“程序驱动/事件驱动”的变量,难以用纯 DDPM 学好
|
||
|
||
---
|
||
|
||
## 2. 数据与特征结构
|
||
|
||
**数据来源**:HAI `train*.csv.gz`(多文件)
|
||
|
||
**特征拆分**(见 `example/feature_split.json`):
|
||
- `continuous`:连续变量(传感器/执行器)
|
||
- `discrete`:离散变量(状态/模式)
|
||
- `time_column`:时间列(不参与训练)
|
||
|
||
---
|
||
|
||
## 3. 预处理与统计文件
|
||
|
||
脚本:`example/prepare_data.py`
|
||
|
||
### 3.1 连续变量
|
||
- 计算 mean/std
|
||
- 若开启 `use_quantile_transform`:计算分位数表(CDF)
|
||
- 输出:`example/results/cont_stats.json`
|
||
|
||
### 3.2 离散变量
|
||
- 统计 vocab
|
||
- 输出:`example/results/disc_vocab.json`
|
||
|
||
### 3.3 数据工具
|
||
`example/data_utils.py` 提供:
|
||
- 标准化/反标准化
|
||
- 分位数变换/逆变换
|
||
- 可选后校准(quantile calibration)
|
||
|
||
---
|
||
|
||
## 4. 模型总体架构
|
||
|
||
本项目采用 **两阶段 + 混合扩散** 架构:
|
||
|
||
### 4.1 Stage‑1 Temporal GRU
|
||
- 目的:学习序列趋势、时序结构
|
||
- 输入:连续变量序列
|
||
- 输出:trend(趋势序列)
|
||
|
||
### 4.2 Stage‑2 Hybrid Diffusion
|
||
- 目的:学习残差分布(把时序和分布解耦)
|
||
- 连续变量:Gaussian DDPM
|
||
- 离散变量:mask diffusion 分类 head
|
||
|
||
### 4.3 Backbone 选择
|
||
- 当前配置:`backbone_type = transformer`
|
||
- 可选:GRU(更省显存更稳定)
|
||
|
||
---
|
||
|
||
## 5. 训练流程(逐步骤)
|
||
|
||
脚本:`example/train.py`
|
||
|
||
### Step 1:Temporal 训练
|
||
- 输入:连续序列
|
||
- GRU teacher‑forcing 预测下一步
|
||
- Loss:MSE
|
||
- 输出:`temporal.pt`
|
||
|
||
### Step 2:Diffusion 训练
|
||
- 计算残差:`x_resid = x_cont - trend`
|
||
- 采样时间步 t
|
||
- 连续:加噪
|
||
- 离散:mask token
|
||
- 模型预测 eps / logits
|
||
|
||
### Loss 设计
|
||
- Continuous:MSE(eps 或 x0)
|
||
- Discrete:Cross Entropy(mask 部分)
|
||
- 总损失:`loss = λ * loss_cont + (1-λ) * loss_disc`
|
||
- 可选加权:
|
||
- inverse‑std
|
||
- SNR‑weighted
|
||
- quantile loss
|
||
- residual stat loss
|
||
|
||
---
|
||
|
||
## 6. 采样与导出流程
|
||
|
||
脚本:`example/export_samples.py`
|
||
|
||
流程:
|
||
1) 初始化噪声(连续)
|
||
2) 初始化 mask(离散)
|
||
3) 反扩散 t=T..0
|
||
4) 加回 trend
|
||
5) 反变换(quantile/标准化)
|
||
6) 合成 CSV
|
||
|
||
输出:`example/results/generated.csv`
|
||
|
||
---
|
||
|
||
## 7. 评估体系与指标
|
||
|
||
脚本:`example/evaluate_generated.py`
|
||
|
||
### 连续指标
|
||
- **KS(tie‑aware)**
|
||
- quantile diff
|
||
- lag‑1 correlation
|
||
|
||
### 离散指标
|
||
- JSD
|
||
- invalid token 比例
|
||
|
||
### Reference 读取
|
||
- 支持 `train*.csv.gz` glob
|
||
- 自动汇总所有文件
|
||
|
||
---
|
||
|
||
## 8. 诊断工具与常用脚本
|
||
|
||
- `diagnose_ks.py`:CDF 可视化
|
||
- `ranked_ks.py`:KS 贡献排序
|
||
- `filtered_metrics.py`:过滤异常特征后的 KS
|
||
- `program_stats.py`:Type1 统计
|
||
- `controller_stats.py`:Type2 统计
|
||
- `actuator_stats.py`:Type3 统计
|
||
- `pv_stats.py`:Type4 统计
|
||
- `aux_stats.py`:Type6 统计
|
||
|
||
---
|
||
|
||
## 9. Type‑aware 设计(按类型分治)
|
||
|
||
在真实 ICS 中,部分变量很难用 DDPM 学到,所以做类型划分:
|
||
|
||
- **Type1**:setpoint/demand(调度驱动)
|
||
- **Type2**:controller outputs
|
||
- **Type3**:actuator positions
|
||
- **Type4**:PV sensors
|
||
- **Type5**:derived tags
|
||
- **Type6**:aux/coupling
|
||
|
||
脚本:`example/postprocess_types.py`
|
||
|
||
当前实现是 **KS‑only baseline**:
|
||
- Type1/2/3/5/6 → 经验重采样
|
||
- Type4 → 仍用 diffusion
|
||
|
||
用途:
|
||
- 快速诊断“KS 最优可达上界”
|
||
- 不保证联合分布真实性
|
||
|
||
输出:`example/results/generated_post.csv`
|
||
|
||
---
|
||
|
||
## 10. 一键运行与常见命令
|
||
|
||
### 全流程(推荐)
|
||
```bash
|
||
python example/run_all.py --device cuda --config example/config.json
|
||
```
|
||
|
||
### 只评估不训练
|
||
```bash
|
||
python example/run_all.py --skip-prepare --skip-train --skip-export
|
||
```
|
||
|
||
### 只训练不评估
|
||
```bash
|
||
python example/run_all.py --skip-eval --skip-postprocess --skip-post-eval --skip-diagnostics
|
||
```
|
||
|
||
---
|
||
|
||
## 11. 输出文件说明
|
||
|
||
- `generated.csv`:原始 diffusion 输出
|
||
- `generated_post.csv`:KS‑only 后处理输出
|
||
- `eval.json`:原始评估
|
||
- `eval_post.json`:后处理评估
|
||
- `cont_stats.json` / `disc_vocab.json`:统计文件
|
||
- `*_stats.json`:Type 统计报告
|
||
|
||
---
|
||
|
||
## 12. 当前配置(关键超参)
|
||
|
||
来自 `example/config.json`:
|
||
- backbone_type: **transformer**
|
||
- timesteps: 600
|
||
- seq_len: 96
|
||
- batch_size: 16
|
||
- cont_target: x0
|
||
- cont_loss_weighting: inv_std
|
||
- snr_weighted_loss: true
|
||
- quantile_loss_weight: 0.2
|
||
- use_quantile_transform: true
|
||
- cont_post_calibrate: true
|
||
- use_temporal_stage1: true
|
||
|
||
---
|
||
|
||
## 13. 为什么运行慢
|
||
|
||
1) 两阶段训练(temporal + diffusion)
|
||
2) 评估要读全量 train*.csv.gz
|
||
3) run_all 默认跑所有诊断脚本
|
||
4) timesteps / seq_len 大
|
||
|
||
---
|
||
|
||
## 14. 已知限制与后续方向
|
||
|
||
限制:
|
||
- Type1/2/3 仍主导 KS
|
||
- KS‑only baseline 会破坏联合分布
|
||
- 时序和分布存在 trade‑off
|
||
|
||
方向:
|
||
- 为 Type1/2/3 建条件模型
|
||
- Type4 增加 regime conditioning
|
||
- 联合指标(cross‑feature correlation)
|
||
|
||
---
|
||
|
||
## 15. 文件树(精简版)
|
||
|
||
```
|
||
mask-ddpm/
|
||
report.md
|
||
docs/
|
||
README.md
|
||
architecture.md
|
||
evaluation.md
|
||
decisions.md
|
||
experiments.md
|
||
ideas.md
|
||
example/
|
||
config.json
|
||
config_no_temporal.json
|
||
config_temporal_strong.json
|
||
feature_split.json
|
||
data_utils.py
|
||
prepare_data.py
|
||
hybrid_diffusion.py
|
||
train.py
|
||
sample.py
|
||
export_samples.py
|
||
evaluate_generated.py
|
||
run_all.py
|
||
run_compare.py
|
||
diagnose_ks.py
|
||
filtered_metrics.py
|
||
ranked_ks.py
|
||
program_stats.py
|
||
controller_stats.py
|
||
actuator_stats.py
|
||
pv_stats.py
|
||
aux_stats.py
|
||
postprocess_types.py
|
||
results/
|
||
```
|
||
|
||
---
|
||
|
||
## 16. 文件职责(逐文件说明)
|
||
|
||
- `prepare_data.py`:统计连续/离散特征
|
||
- `data_utils.py`:预处理与变换函数
|
||
- `hybrid_diffusion.py`:模型主体(Temporal + Diffusion)
|
||
- `train.py`:两阶段训练
|
||
- `export_samples.py`:采样导出
|
||
- `evaluate_generated.py`:评估指标
|
||
- `run_all.py`:一键流程
|
||
- `postprocess_types.py`:Type‑aware KS‑only baseline
|
||
- `diagnose_ks.py`:CDF 诊断
|
||
- `ranked_ks.py`:KS 排序
|
||
- `filtered_metrics.py`:过滤 KS
|
||
|
||
---
|
||
|
||
# 结束
|
||
如果你需要更“论文式”的版本(加入公式、伪代码、实验表格),可以继续追加。
|