7.4 KiB
7.4 KiB
mask-ddpm 项目说明书(完整详细版)
本文档是“说明书级别”的完整描述,面向首次接触项目的同学。 目标是让不了解扩散/时序建模的人也能理解:项目是什么、怎么跑、每个文件干什么、每一步在训练什么、为什么这么设计。
适用范围:当前仓库代码(以
example/config.json为主配置)。
目录
- 项目目标与研究问题
- 数据与特征结构
- 预处理与统计文件
- 模型总体架构
- 训练流程(逐步骤)
- 采样与导出流程
- 评估体系与指标
- 诊断工具与常用脚本
- Type‑aware(按类型分治)设计
- 一键运行与常见命令
- 输出文件说明
- 当前配置与关键超参
- 常见问题与慢的原因
- 已知限制与后续方向
- 文件树(精简版)
- 文件职责(逐文件说明)
1. 项目目标与研究问题
本项目目标:生成工业控制系统(ICS)多变量时序数据,满足以下三点:
- 分布一致性:每个变量的统计分布接近真实(用 KS 衡量)
- 时序一致性:序列结构合理,lag‑1 相关性、趋势符合真实
- 离散合法性:离散变量(状态/模式)必须是合法 token 且分布合理(JSD)
核心难点:
- 时序结构和分布对齐经常相互冲突
- 真实数据包含“程序驱动/事件驱动”的变量,难以用纯 DDPM 学好
2. 数据与特征结构
数据来源:HAI train*.csv.gz(多文件)
特征拆分(见 example/feature_split.json):
continuous:连续变量(传感器/执行器)discrete:离散变量(状态/模式)time_column:时间列(不参与训练)
3. 预处理与统计文件
脚本:example/prepare_data.py
3.1 连续变量
- 计算 mean/std
- 若开启
use_quantile_transform:计算分位数表(CDF) - 输出:
example/results/cont_stats.json
3.2 离散变量
- 统计 vocab
- 输出:
example/results/disc_vocab.json
3.3 数据工具
example/data_utils.py 提供:
- 标准化/反标准化
- 分位数变换/逆变换
- 可选后校准(quantile calibration)
4. 模型总体架构
本项目采用 两阶段 + 混合扩散 架构:
4.1 Stage‑1 Temporal GRU
- 目的:学习序列趋势、时序结构
- 输入:连续变量序列
- 输出:trend(趋势序列)
4.2 Stage‑2 Hybrid Diffusion
- 目的:学习残差分布(把时序和分布解耦)
- 连续变量:Gaussian DDPM
- 离散变量:mask diffusion 分类 head
4.3 Backbone 选择
- 当前配置:
backbone_type = transformer - 可选:GRU(更省显存更稳定)
5. 训练流程(逐步骤)
脚本:example/train.py
Step 1:Temporal 训练
- 输入:连续序列
- GRU teacher‑forcing 预测下一步
- Loss:MSE
- 输出:
temporal.pt
Step 2:Diffusion 训练
- 计算残差:
x_resid = x_cont - trend - 采样时间步 t
- 连续:加噪
- 离散:mask token
- 模型预测 eps / logits
Loss 设计
- Continuous:MSE(eps 或 x0)
- Discrete:Cross Entropy(mask 部分)
- 总损失:
loss = λ * loss_cont + (1-λ) * loss_disc - 可选加权:
- inverse‑std
- SNR‑weighted
- quantile loss
- residual stat loss
6. 采样与导出流程
脚本:example/export_samples.py
流程:
- 初始化噪声(连续)
- 初始化 mask(离散)
- 反扩散 t=T..0
- 加回 trend
- 反变换(quantile/标准化)
- 合成 CSV
输出:example/results/generated.csv
7. 评估体系与指标
脚本:example/evaluate_generated.py
连续指标
- KS(tie‑aware)
- quantile diff
- lag‑1 correlation
离散指标
- JSD
- invalid token 比例
Reference 读取
- 支持
train*.csv.gzglob - 自动汇总所有文件
8. 诊断工具与常用脚本
diagnose_ks.py:CDF 可视化ranked_ks.py:KS 贡献排序filtered_metrics.py:过滤异常特征后的 KSprogram_stats.py:Type1 统计controller_stats.py:Type2 统计actuator_stats.py:Type3 统计pv_stats.py:Type4 统计aux_stats.py:Type6 统计
9. Type‑aware 设计(按类型分治)
在真实 ICS 中,部分变量很难用 DDPM 学到,所以做类型划分:
- Type1:setpoint/demand(调度驱动)
- Type2:controller outputs
- Type3:actuator positions
- Type4:PV sensors
- Type5:derived tags
- Type6:aux/coupling
脚本:example/postprocess_types.py
当前实现是 KS‑only baseline:
- Type1/2/3/5/6 → 经验重采样
- Type4 → 仍用 diffusion
用途:
- 快速诊断“KS 最优可达上界”
- 不保证联合分布真实性
输出:example/results/generated_post.csv
10. 一键运行与常见命令
全流程(推荐)
python example/run_all.py --device cuda --config example/config.json
只评估不训练
python example/run_all.py --skip-prepare --skip-train --skip-export
只训练不评估
python example/run_all.py --skip-eval --skip-postprocess --skip-post-eval --skip-diagnostics
11. 输出文件说明
generated.csv:原始 diffusion 输出generated_post.csv:KS‑only 后处理输出eval.json:原始评估eval_post.json:后处理评估cont_stats.json/disc_vocab.json:统计文件*_stats.json:Type 统计报告
12. 当前配置(关键超参)
来自 example/config.json:
- backbone_type: transformer
- timesteps: 600
- seq_len: 96
- batch_size: 16
- cont_target: x0
- cont_loss_weighting: inv_std
- snr_weighted_loss: true
- quantile_loss_weight: 0.2
- use_quantile_transform: true
- cont_post_calibrate: true
- use_temporal_stage1: true
13. 为什么运行慢
- 两阶段训练(temporal + diffusion)
- 评估要读全量 train*.csv.gz
- run_all 默认跑所有诊断脚本
- timesteps / seq_len 大
14. 已知限制与后续方向
限制:
- Type1/2/3 仍主导 KS
- KS‑only baseline 会破坏联合分布
- 时序和分布存在 trade‑off
方向:
- 为 Type1/2/3 建条件模型
- Type4 增加 regime conditioning
- 联合指标(cross‑feature correlation)
15. 文件树(精简版)
mask-ddpm/
report.md
docs/
README.md
architecture.md
evaluation.md
decisions.md
experiments.md
ideas.md
example/
config.json
config_no_temporal.json
config_temporal_strong.json
feature_split.json
data_utils.py
prepare_data.py
hybrid_diffusion.py
train.py
sample.py
export_samples.py
evaluate_generated.py
run_all.py
run_compare.py
diagnose_ks.py
filtered_metrics.py
ranked_ks.py
program_stats.py
controller_stats.py
actuator_stats.py
pv_stats.py
aux_stats.py
postprocess_types.py
results/
16. 文件职责(逐文件说明)
prepare_data.py:统计连续/离散特征data_utils.py:预处理与变换函数hybrid_diffusion.py:模型主体(Temporal + Diffusion)train.py:两阶段训练export_samples.py:采样导出evaluate_generated.py:评估指标run_all.py:一键流程postprocess_types.py:Type‑aware KS‑only baselinediagnose_ks.py:CDF 诊断ranked_ks.py:KS 排序filtered_metrics.py:过滤 KS
结束
如果你需要更“论文式”的版本(加入公式、伪代码、实验表格),可以继续追加。