Update example and notes
27  CONVERSATION.md  Normal file
@@ -0,0 +1,27 @@
# Conversation Log

Date: 2025-02-14

The user asked me to first read the MD documents in `/home/anay/Dev/diffusion/mask-ddpm` and proceed according to their instructions, and to add a new MD file to the project recording our conversation so that later AI sessions have the context.

Documents read:

- `/home/anay/Dev/diffusion/mask-ddpm/TODO-FOR-AI.md`

Current status:

- Finished reading the documents.
- Created this conversation log file.

Follow-up notes:

- Started inspecting the datasets. The HAI CSV files are Git LFS pointer files, so the real content still needs to be pulled; the Modbus dataset's visible CSVs are mainly attack logs (fields such as Timestamp/TargetIP/Attack/TransactionID), and its main data is PCAP, which would require extra feature extraction.
- Read `/home/anay/Dev/diffusion/dataset/hai/README.md`. Key points: the HAI CSVs are time-series feature data with a timestamp in the first column, SCADA points in the following columns, and attack label columns at the end; starting with HAI 22.04, Git LFS is used to download the actual CSV content. (See the snippet at the end of this file.)
- When the user ran `git lfs pull`, the LFS quota was exceeded (the upstream repository's quota is exhausted), so the real HAI CSV content could not be fetched.
- The user chose to download the HAI dataset from Kaggle and place it into the local directory manually.
- Readable HAI CSVs were found locally in gzip form: `/home/anay/Dev/diffusion/dataset/hai/hai-20.07/*.csv.gz` and `/home/anay/Dev/diffusion/dataset/hai/hai-21.03/*.csv.gz`; the columns include a large number of SCADA points plus attack label columns.
- Based on heuristic statistics over the first 5000 rows of `hai-21.03/train1.csv.gz`, made a preliminary discrete/continuous feature split (discrete columns are mostly switches, states, and attack labels; continuous columns are sensor/process variables).
- Created an example folder at `/home/anay/Dev/diffusion/mask-ddpm/example` containing `analyze_hai21_03.py` and its outputs (`results/feature_split.txt`, `results/summary.txt`).
- Added example code and docs: `feature_split.json`, `hybrid_diffusion.py`, `train_stub.py`, `model_design.md`, and updated `/home/anay/Dev/diffusion/mask-ddpm/example/README.md`.
- Added runnable scripts and data preparation: `data_utils.py`, `prepare_data.py`, `train.py`, `sample.py`, and fixed `train_stub.py` to match the new discrete masking interface.
- Ran `prepare_data.py` to generate `results/cont_stats.json` and `results/disc_vocab.json` (sampling 50k rows).
- The user plans to create a new conda environment and install the GPU build of PyTorch; suggested environment name `mask-ddpm`, Python 3.10, using the cu121 wheel index.
- Ran `example/train.py` (CPU fallback with a CUDA initialization warning), producing `example/results/model.pt`; ran `example/sample.py`, which returned sampled tensors of shape `(2, 64, 53)` and `(2, 64, 30)`, and switched model loading to `weights_only=True`.
- Changed `train.py` and `sample.py` to auto-select the GPU (use it when available, otherwise fall back to CPU), and noted that `/dev/nvidia*` does not exist in the current environment, so CUDA is unavailable.
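
For reference, a minimal way to double-check the CSV layout described above (a gzip header peek), assuming the Kaggle copy sits at the path used by the example scripts:

```
import csv
import gzip

path = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
with gzip.open(path, "rt", newline="") as f:
    cols = next(csv.reader(f))

print(len(cols), "columns")  # summary.txt reports 84 columns for this file
print(cols[0])               # expected to be the "time" column
print(cols[-4:])             # expected to be the attack label columns
```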
44  example/README.md  Normal file
@@ -0,0 +1,44 @@
# Example: HAI 21.03 Feature Split

This folder contains a small, reproducible example that inspects the HAI 21.03
CSV (train1) and produces a continuous/discrete split using a simple heuristic.

## Files
- analyze_hai21_03.py: reads a sample of the data and writes results.
- data_utils.py: CSV loading, vocab, normalization, and batching helpers.
- feature_split.json: column split for HAI 21.03.
- hybrid_diffusion.py: hybrid model + diffusion utilities.
- prepare_data.py: compute vocab and normalization stats.
- train_stub.py: end-to-end scaffold for loss computation.
- train.py: minimal training loop with checkpoints.
- sample.py: minimal sampling loop.
- model_design.md: step-by-step design notes.
- results/feature_split.txt: comma-separated feature lists.
- results/summary.txt: basic stats (rows sampled, column counts).

## Run
```
python /home/anay/Dev/diffusion/mask-ddpm/example/analyze_hai21_03.py
```

Prepare vocab + stats (writes to `example/results`):
```
python /home/anay/Dev/diffusion/mask-ddpm/example/prepare_data.py
```

Train a small run:
```
python /home/anay/Dev/diffusion/mask-ddpm/example/train.py
```

Sample from the trained model:
```
python /home/anay/Dev/diffusion/mask-ddpm/example/sample.py
```

## Notes
- Heuristic: integer-like values with low cardinality (<=10) are treated as discrete; all other numeric columns are continuous (see the sketch below).
- The script only samples the first 5000 rows to stay fast.
- `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
- `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU.
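
A minimal sketch of that rule, written as a hypothetical helper `is_discrete` on a toy column (the cardinality cutoff and integer test mirror `analyze_hai21_03.py`):

```
def is_discrete(values, max_card=10, tol=1e-9):
    """Non-numeric -> discrete; numeric -> discrete only if every value is
    integer-like and the number of distinct values is <= max_card."""
    seen = set()
    for v in values:
        try:
            fv = float(v)
        except (TypeError, ValueError):
            return True   # non-numeric column -> discrete
        if abs(fv - round(fv)) > tol:
            return False  # fractional value -> continuous
        seen.add(fv)
    return len(seen) <= max_card

print(is_discrete(["0", "1", "1", "0"]))         # True: binary switch/state column
print(is_discrete(["396.18", "396.20", "397"]))  # False: fractional sensor readings
```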
BIN  example/__pycache__/data_utils.cpython-39.pyc  Normal file
Binary file not shown.
BIN  example/__pycache__/hybrid_diffusion.cpython-39.pyc  Normal file
Binary file not shown.
104  example/analyze_hai21_03.py  Executable file
@@ -0,0 +1,104 @@
#!/usr/bin/env python3
"""Analyze HAI 21.03 CSV to split features into continuous/discrete.

Heuristic: integer-like values with low cardinality (<=10) -> discrete.
Everything else numeric -> continuous. Non-numeric -> discrete.
"""

import csv
import gzip
import os

DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
OUT_DIR = "/home/anay/Dev/diffusion/mask-ddpm/example/results"
MAX_ROWS = 5000


def analyze(path: str, max_rows: int):
    with gzip.open(path, "rt", newline="") as f:
        reader = csv.reader(f)
        cols = next(reader)
        stats = {
            c: {"numeric": True, "int_like": True, "unique": set(), "count": 0}
            for c in cols
        }
        rows = 0
        for row in reader:
            rows += 1
            for c, v in zip(cols, row):
                st = stats[c]
                if v == "" or v is None:
                    continue
                st["count"] += 1
                if st["numeric"]:
                    try:
                        fv = float(v)
                    except Exception:
                        st["numeric"] = False
                        st["int_like"] = False
                        st["unique"].add(v)
                        continue
                    if st["int_like"] and abs(fv - round(fv)) > 1e-9:
                        st["int_like"] = False
                    if len(st["unique"]) < 50:
                        st["unique"].add(fv)
                else:
                    if len(st["unique"]) < 50:
                        st["unique"].add(v)
            if rows >= max_rows:
                break

    continuous = []
    discrete = []
    unknown = []
    for c in cols:
        if c == "time":
            continue
        st = stats[c]
        if st["count"] == 0:
            unknown.append(c)
            continue
        if not st["numeric"]:
            discrete.append(c)
            continue
        unique_count = len(st["unique"])
        if st["int_like"] and unique_count <= 10:
            discrete.append(c)
        else:
            continuous.append(c)

    return cols, continuous, discrete, unknown, rows


def write_results(cols, continuous, discrete, unknown, rows):
    os.makedirs(OUT_DIR, exist_ok=True)
    split_path = os.path.join(OUT_DIR, "feature_split.txt")
    summary_path = os.path.join(OUT_DIR, "summary.txt")

    with open(split_path, "w", encoding="ascii") as f:
        f.write("discrete\n")
        f.write(",".join(discrete) + "\n")
        f.write("continuous\n")
        f.write(",".join(continuous) + "\n")
        if unknown:
            f.write("unknown\n")
            f.write(",".join(unknown) + "\n")

    with open(summary_path, "w", encoding="ascii") as f:
        f.write("rows_sampled: %d\n" % rows)
        f.write("columns_total: %d\n" % len(cols))
        f.write("continuous: %d\n" % len(continuous))
        f.write("discrete: %d\n" % len(discrete))
        f.write("unknown: %d\n" % len(unknown))
        f.write("data_path: %s\n" % DATA_PATH)


def main():
    if not os.path.exists(DATA_PATH):
        raise SystemExit("missing data file: %s" % DATA_PATH)
    cols, continuous, discrete, unknown, rows = analyze(DATA_PATH, MAX_ROWS)
    write_results(cols, continuous, discrete, unknown, rows)


if __name__ == "__main__":
    main()
124  example/data_utils.py  Executable file
@@ -0,0 +1,124 @@
#!/usr/bin/env python3
"""Small utilities for HAI 21.03 data loading and feature encoding."""

import csv
import gzip
import json
from typing import Dict, Iterable, List, Optional, Tuple


def load_split(path: str) -> Dict[str, List[str]]:
    with open(path, "r", encoding="ascii") as f:
        return json.load(f)


def iter_rows(path: str) -> Iterable[Dict[str, str]]:
    with gzip.open(path, "rt", newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            yield row


def compute_cont_stats(
    path: str,
    cont_cols: List[str],
    max_rows: Optional[int] = None,
) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Streaming mean/std (Welford)."""
    count = 0
    mean = {c: 0.0 for c in cont_cols}
    m2 = {c: 0.0 for c in cont_cols}

    for i, row in enumerate(iter_rows(path)):
        count += 1
        for c in cont_cols:
            x = float(row[c])
            delta = x - mean[c]
            mean[c] += delta / count
            delta2 = x - mean[c]
            m2[c] += delta * delta2
        if max_rows is not None and i + 1 >= max_rows:
            break

    std = {}
    for c in cont_cols:
        if count > 1:
            var = m2[c] / (count - 1)
        else:
            var = 0.0
        std[c] = var ** 0.5 if var > 0 else 1.0
    return mean, std


def build_vocab(
    path: str,
    disc_cols: List[str],
    max_rows: Optional[int] = None,
) -> Dict[str, Dict[str, int]]:
    values = {c: set() for c in disc_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in disc_cols:
            values[c].add(row[c])
        if max_rows is not None and i + 1 >= max_rows:
            break

    vocab = {}
    for c in disc_cols:
        tokens = sorted(values[c])
        vocab[c] = {tok: idx for idx, tok in enumerate(tokens)}
    return vocab


def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[str, float]):
    import torch

    mean_t = torch.tensor([mean[c] for c in cont_cols], dtype=x.dtype, device=x.device)
    std_t = torch.tensor([std[c] for c in cont_cols], dtype=x.dtype, device=x.device)
    return (x - mean_t) / std_t


def windowed_batches(
    path: str,
    cont_cols: List[str],
    disc_cols: List[str],
    vocab: Dict[str, Dict[str, int]],
    mean: Dict[str, float],
    std: Dict[str, float],
    batch_size: int,
    seq_len: int,
    max_batches: Optional[int] = None,
):
    import torch

    batch_cont = []
    batch_disc = []
    seq_cont = []
    seq_disc = []

    def flush_seq():
        nonlocal seq_cont, seq_disc, batch_cont, batch_disc
        if len(seq_cont) == seq_len:
            batch_cont.append(seq_cont)
            batch_disc.append(seq_disc)
        seq_cont = []
        seq_disc = []

    batches_yielded = 0
    for row in iter_rows(path):
        cont_row = [float(row[c]) for c in cont_cols]
        disc_row = [vocab[c][row[c]] for c in disc_cols]
        seq_cont.append(cont_row)
        seq_disc.append(disc_row)
        if len(seq_cont) == seq_len:
            flush_seq()
            if len(batch_cont) == batch_size:
                x_cont = torch.tensor(batch_cont, dtype=torch.float32)
                x_disc = torch.tensor(batch_disc, dtype=torch.long)
                x_cont = normalize_cont(x_cont, cont_cols, mean, std)
                yield x_cont, x_disc
                batch_cont = []
                batch_disc = []
                batches_yielded += 1
                if max_batches is not None and batches_yielded >= max_batches:
                    return

    # Drop last partial batch for simplicity
90  example/feature_split.json  Normal file
@@ -0,0 +1,90 @@
{
  "time_column": "time",
  "continuous": [
    "P1_B2004",
    "P1_B2016",
    "P1_B3004",
    "P1_B3005",
    "P1_B4002",
    "P1_B4005",
    "P1_B400B",
    "P1_B4022",
    "P1_FCV02Z",
    "P1_FCV03D",
    "P1_FCV03Z",
    "P1_FT01",
    "P1_FT01Z",
    "P1_FT02",
    "P1_FT02Z",
    "P1_FT03",
    "P1_FT03Z",
    "P1_LCV01D",
    "P1_LCV01Z",
    "P1_LIT01",
    "P1_PCV01D",
    "P1_PCV01Z",
    "P1_PCV02Z",
    "P1_PIT01",
    "P1_PIT02",
    "P1_TIT01",
    "P1_TIT02",
    "P2_24Vdc",
    "P2_CO_rpm",
    "P2_HILout",
    "P2_MSD",
    "P2_SIT01",
    "P2_SIT02",
    "P2_VT01",
    "P2_VXT02",
    "P2_VXT03",
    "P2_VYT02",
    "P2_VYT03",
    "P3_FIT01",
    "P3_LCP01D",
    "P3_LCV01D",
    "P3_LIT01",
    "P3_PIT01",
    "P4_HT_FD",
    "P4_HT_LD",
    "P4_HT_PO",
    "P4_LD",
    "P4_ST_FD",
    "P4_ST_GOV",
    "P4_ST_LD",
    "P4_ST_PO",
    "P4_ST_PT01",
    "P4_ST_TT01"
  ],
  "discrete": [
    "P1_FCV01D",
    "P1_FCV01Z",
    "P1_FCV02D",
    "P1_PCV02D",
    "P1_PP01AD",
    "P1_PP01AR",
    "P1_PP01BD",
    "P1_PP01BR",
    "P1_PP02D",
    "P1_PP02R",
    "P1_STSP",
    "P2_ASD",
    "P2_AutoGO",
    "P2_Emerg",
    "P2_ManualGO",
    "P2_OnOff",
    "P2_RTR",
    "P2_TripEx",
    "P2_VTR01",
    "P2_VTR02",
    "P2_VTR03",
    "P2_VTR04",
    "P3_LH",
    "P3_LL",
    "P4_HT_PS",
    "P4_ST_PS",
    "attack",
    "attack_P1",
    "attack_P2",
    "attack_P3"
  ]
}
113  example/hybrid_diffusion.py  Executable file
@@ -0,0 +1,113 @@
#!/usr/bin/env python3
"""Hybrid diffusion scaffold for continuous + discrete HAI features.

Continuous: Gaussian diffusion (DDPM-style).
Discrete: mask-based diffusion (predict original token).
"""

import math
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> torch.Tensor:
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 1e-5, 0.999)


def q_sample_continuous(x0: torch.Tensor, t: torch.Tensor, alphas_cumprod: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add Gaussian noise to continuous features at timestep t."""
    noise = torch.randn_like(x0)
    a_bar = alphas_cumprod[t].view(-1, 1, 1)
    xt = torch.sqrt(a_bar) * x0 + torch.sqrt(1.0 - a_bar) * noise
    return xt, noise


def q_sample_discrete(
    x0: torch.Tensor,
    t: torch.Tensor,
    mask_tokens: torch.Tensor,
    max_t: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Randomly mask discrete tokens with a linear schedule over t."""
    bsz = x0.size(0)
    p = t.float() / float(max_t)
    p = p.view(bsz, 1, 1)
    mask = torch.rand_like(x0.float()) < p
    x_masked = x0.clone()
    for i in range(x0.size(2)):
        x_masked[:, :, i][mask[:, :, i]] = mask_tokens[i]
    return x_masked, mask


class SinusoidalTimeEmbedding(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(0, half, device=t.device) / half)
        args = t.float().unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
        if self.dim % 2 == 1:
            emb = F.pad(emb, (0, 1))
        return emb


class HybridDiffusionModel(nn.Module):
    def __init__(
        self,
        cont_dim: int,
        disc_vocab_sizes: List[int],
        time_dim: int = 64,
        hidden_dim: int = 256,
    ):
        super().__init__()
        self.cont_dim = cont_dim
        self.disc_vocab_sizes = disc_vocab_sizes

        self.time_embed = SinusoidalTimeEmbedding(time_dim)

        # One embedding table per discrete feature; the extra slot is the
        # [MASK] token (its id is vocab_size, matching mask_tokens in train.py/sample.py).
        self.disc_embeds = nn.ModuleList([
            nn.Embedding(vocab_size + 1, min(32, vocab_size * 2))
            for vocab_size in disc_vocab_sizes
        ])
        disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)

        self.cont_proj = nn.Linear(cont_dim, cont_dim)
        self.in_proj = nn.Linear(cont_dim + disc_embed_dim + time_dim, hidden_dim)
        self.backbone = nn.GRU(hidden_dim, hidden_dim, batch_first=True)

        self.cont_head = nn.Linear(hidden_dim, cont_dim)
        self.disc_heads = nn.ModuleList([
            nn.Linear(hidden_dim, vocab_size)
            for vocab_size in disc_vocab_sizes
        ])

    def forward(self, x_cont: torch.Tensor, x_disc: torch.Tensor, t: torch.Tensor):
        """x_cont: (B,T,Cc), x_disc: (B,T,Cd) with integer tokens."""
        time_emb = self.time_embed(t)
        time_emb = time_emb.unsqueeze(1).expand(-1, x_cont.size(1), -1)

        disc_embs = []
        for i, emb in enumerate(self.disc_embeds):
            disc_embs.append(emb(x_disc[:, :, i]))
        disc_feat = torch.cat(disc_embs, dim=-1)

        cont_feat = self.cont_proj(x_cont)
        feat = torch.cat([cont_feat, disc_feat, time_emb], dim=-1)
        feat = self.in_proj(feat)

        out, _ = self.backbone(feat)

        eps_pred = self.cont_head(out)
        logits = [head(out) for head in self.disc_heads]
        return eps_pred, logits
45  example/model_design.md  Normal file
@@ -0,0 +1,45 @@
# Hybrid Diffusion Design (HAI 21.03)

## 1) Data representation
- Input sequence length: T (e.g., 64 or 128 time steps).
- Continuous features: 53 columns (sensor/process values).
- Discrete features: 30 columns (binary or low-cardinality states + attack labels).
- Time column: `time` is excluded from modeling; use index-based position/time embeddings.

## 2) Forward processes
### Continuous (Gaussian DDPM)
- Use cosine beta schedule with `timesteps=1000`.
- Forward: `x_t = sqrt(a_bar_t) * x_0 + sqrt(1-a_bar_t) * eps`.

### Discrete (mask diffusion)
- Use `[MASK]` replacement with probability `p(t)`.
- Simple schedule: `p(t) = t / T`.
- Model predicts original token at masked positions only.

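Both corruptions are implemented in `hybrid_diffusion.py`. A minimal sketch of applying them to a dummy batch (random data and toy vocab sizes, assuming the module is importable from this folder):

```
import torch

from hybrid_diffusion import cosine_beta_schedule, q_sample_continuous, q_sample_discrete

B, T, C_CONT, C_DISC, STEPS = 4, 64, 53, 30, 1000
alphas_cumprod = torch.cumprod(1.0 - cosine_beta_schedule(STEPS), dim=0)

x_cont = torch.randn(B, T, C_CONT)            # already-normalized continuous features
x_disc = torch.randint(0, 2, (B, T, C_DISC))  # toy binary tokens per discrete column
t = torch.randint(0, STEPS, (B,))             # one diffusion timestep per sequence

# Continuous forward: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
x_cont_t, eps = q_sample_continuous(x_cont, t, alphas_cumprod)

# Discrete forward: replace tokens by each column's [MASK] id with prob p(t) = t / T
mask_tokens = torch.full((C_DISC,), 2)        # toy vocab size 2 -> mask id 2
x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, STEPS)
```
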
## 3) Shared backbone + heads
- Inputs: concatenated continuous projection + discrete embeddings + time embedding.
- Backbone: GRU or temporal transformer.
- Heads:
  - Continuous head predicts noise `eps`.
  - Discrete heads predict logits per discrete feature.

## 4) Loss
- Continuous: `L_cont = MSE(eps_pred, eps)`.
- Discrete: `L_disc = CE(logits, target)` on masked positions only.
- Combined: `L = lambda * L_cont + (1 - lambda) * L_disc`.

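A minimal sketch of this combined loss, written as a hypothetical helper `hybrid_loss`; the masked-positions-only cross-entropy mirrors what `train.py` and `train_stub.py` compute:

```
import torch.nn.functional as F

def hybrid_loss(eps_pred, noise, logits, x_disc, mask, lam=0.5):
    # Continuous branch: MSE between predicted and true noise.
    loss_cont = F.mse_loss(eps_pred, noise)
    # Discrete branch: per-feature cross-entropy restricted to masked positions.
    loss_disc = 0.0
    for i, logit in enumerate(logits):
        m = mask[:, :, i]
        if m.any():
            loss_disc = loss_disc + F.cross_entropy(logit[m], x_disc[:, :, i][m])
    return lam * loss_cont + (1 - lam) * loss_disc
```
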
## 5) Training loop (high level)
1. Load a batch of sequences.
2. Sample timesteps `t`.
3. Apply `q_sample_continuous` and `q_sample_discrete`.
4. Forward model, compute losses.
5. Backprop + optimizer step.

## 6) Sampling (high level)
- Continuous: standard reverse diffusion from pure noise.
- Discrete: start from all `[MASK]` and iteratively refine tokens.

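For reference, the continuous reverse update used by `sample.py` is the standard DDPM step `x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - a_bar_t) * eps_pred) + sqrt(beta_t) * z`, with `z ~ N(0, I)` for `t > 0` and `z = 0` at `t = 0`; the discrete branch fills still-masked positions by sampling from the softmax of the logits and switches to argmax at `t = 0`.
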
## 7) Files in this example
- `feature_split.json`: column split for HAI 21.03.
- `hybrid_diffusion.py`: model + diffusion utilities.
- `train_stub.py`: end-to-end scaffold for loss computation.
32  example/prepare_data.py  Executable file
@@ -0,0 +1,32 @@
#!/usr/bin/env python3
"""Prepare vocab and normalization stats for HAI 21.03."""

import json
from typing import Optional

from data_utils import compute_cont_stats, build_vocab, load_split

DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
OUT_STATS = "/home/anay/Dev/diffusion/mask-ddpm/example/results/cont_stats.json"
OUT_VOCAB = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"


def main(max_rows: Optional[int] = None):
    split = load_split(SPLIT_PATH)
    cont_cols = split["continuous"]
    disc_cols = split["discrete"]

    mean, std = compute_cont_stats(DATA_PATH, cont_cols, max_rows=max_rows)
    vocab = build_vocab(DATA_PATH, disc_cols, max_rows=max_rows)

    with open(OUT_STATS, "w", encoding="ascii") as f:
        json.dump({"mean": mean, "std": std, "max_rows": max_rows}, f, indent=2)

    with open(OUT_VOCAB, "w", encoding="ascii") as f:
        json.dump({"vocab": vocab, "max_rows": max_rows}, f, indent=2)


if __name__ == "__main__":
    # Default: sample 50000 rows for speed. Set to None for full scan.
    main(max_rows=50000)
113  example/results/cont_stats.json  Normal file
@@ -0,0 +1,113 @@
{
  "mean": {
    "P1_B2004": 0.08649086820000026,
    "P1_B2016": 1.376161456000001,
    "P1_B3004": 396.1861596906018,
    "P1_B3005": 1037.372384413793,
    "P1_B4002": 32.564872940799994,
    "P1_B4005": 65.98190757240047,
    "P1_B400B": 1925.0391570245934,
    "P1_B4022": 36.28908066800001,
    "P1_FCV02Z": 21.744261118400036,
    "P1_FCV03D": 57.36123274140044,
    "P1_FCV03Z": 58.05084519640002,
    "P1_FT01": 184.18615112319728,
    "P1_FT01Z": 851.8781750705965,
    "P1_FT02": 1255.8572173544069,
    "P1_FT02Z": 1925.0210755194114,
    "P1_FT03": 269.37285885780574,
    "P1_FT03Z": 1037.366172230601,
    "P1_LCV01D": 11.228849048599963,
    "P1_LCV01Z": 10.991610181600016,
    "P1_LIT01": 396.8845311109994,
    "P1_PCV01D": 53.80101618419986,
    "P1_PCV01Z": 54.646640287199595,
    "P1_PCV02Z": 12.017773542800072,
    "P1_PIT01": 1.3692859488000075,
    "P1_PIT02": 0.44459071260000227,
    "P1_TIT01": 35.64255813999988,
    "P1_TIT02": 36.44807823060023,
    "P2_24Vdc": 28.0280019013999,
    "P2_CO_rpm": 54105.64434999997,
    "P2_HILout": 712.0588667425922,
    "P2_MSD": 763.19324,
    "P2_SIT01": 778.7769850000013,
    "P2_SIT02": 778.7778935471981,
    "P2_VT01": 11.914949448200044,
    "P2_VXT02": -3.5267871940000175,
    "P2_VXT03": -1.5520904921999914,
    "P2_VYT02": 3.796112737600002,
    "P2_VYT03": 6.121691697000018,
    "P3_FIT01": 1168.2528800000014,
    "P3_LCP01D": 4675.465239999989,
    "P3_LCV01D": 7445.208720000017,
    "P3_LIT01": 13728.982314999852,
    "P3_PIT01": 668.9722350000003,
    "P4_HT_FD": -0.00010012580000000082,
    "P4_HT_LD": 35.41945000099953,
    "P4_HT_PO": 35.4085699912002,
    "P4_LD": 365.3833745803986,
    "P4_ST_FD": -6.5205999999999635e-06,
    "P4_ST_GOV": 17801.81294499996,
    "P4_ST_LD": 329.83259218199964,
    "P4_ST_PO": 330.1079461497967,
    "P4_ST_PT01": 10047.679605000127,
    "P4_ST_TT01": 27606.860070000155
  },
  "std": {
    "P1_B2004": 0.024492489898690458,
    "P1_B2016": 0.12949272564759745,
    "P1_B3004": 10.16264800653289,
    "P1_B3005": 70.85697659109,
    "P1_B4002": 0.7578213113008356,
    "P1_B4005": 41.80065314991797,
    "P1_B400B": 1176.6445547448632,
    "P1_B4022": 0.8221115066487089,
    "P1_FCV02Z": 39.11843197764176,
    "P1_FCV03D": 7.889507447726624,
    "P1_FCV03Z": 8.046068905945717,
    "P1_FT01": 30.80117031882856,
    "P1_FT01Z": 91.2786865433318,
    "P1_FT02": 879.7163277334494,
    "P1_FT02Z": 1176.6699531305114,
    "P1_FT03": 38.18015841964941,
    "P1_FT03Z": 70.73100774436428,
    "P1_LCV01D": 3.3355655415557597,
    "P1_LCV01Z": 3.386332233773545,
    "P1_LIT01": 10.57871476010412,
    "P1_PCV01D": 19.61567943613885,
    "P1_PCV01Z": 19.778754467302086,
    "P1_PCV02Z": 0.004804797893159998,
    "P1_PIT01": 0.0776614954053113,
    "P1_PIT02": 0.44823231815652304,
    "P1_TIT01": 0.5986678527528815,
    "P1_TIT02": 1.1892341204521049,
    "P2_24Vdc": 0.00320884250409781,
    "P2_CO_rpm": 20.57547782150726,
    "P2_HILout": 8.17885337990861,
    "P2_MSD": 1.0,
    "P2_SIT01": 3.894535775667256,
    "P2_SIT02": 3.882477078857941,
    "P2_VT01": 0.06812990916670243,
    "P2_VXT02": 0.43104157117568803,
    "P2_VXT03": 0.26894251958139775,
    "P2_VYT02": 0.46109078832075856,
    "P2_VYT03": 0.3059642938507547,
    "P3_FIT01": 1787.2987693141868,
    "P3_LCP01D": 5145.4094261812725,
    "P3_LCV01D": 6785.602781765096,
    "P3_LIT01": 4060.915441872745,
    "P3_PIT01": 1168.1071264424027,
    "P4_HT_FD": 0.002032582380617592,
    "P4_HT_LD": 33.212361169253235,
    "P4_HT_PO": 31.187825914515162,
    "P4_LD": 59.736616589045646,
    "P4_ST_FD": 0.0016428787127432496,
    "P4_ST_GOV": 1740.5997458128215,
    "P4_ST_LD": 35.86633288900077,
    "P4_ST_PO": 32.375012735256696,
    "P4_ST_PT01": 22.459962818146252,
    "P4_ST_TT01": 24.745939350221477
  },
  "max_rows": 50000
}
30639  example/results/disc_vocab.json  Normal file
File diff suppressed because it is too large
4  example/results/feature_split.txt  Normal file
@@ -0,0 +1,4 @@
discrete
P1_FCV01D,P1_FCV01Z,P1_FCV02D,P1_PCV02D,P1_PP01AD,P1_PP01AR,P1_PP01BD,P1_PP01BR,P1_PP02D,P1_PP02R,P1_STSP,P2_ASD,P2_AutoGO,P2_Emerg,P2_ManualGO,P2_OnOff,P2_RTR,P2_TripEx,P2_VTR01,P2_VTR02,P2_VTR03,P2_VTR04,P3_LH,P3_LL,P4_HT_PS,P4_ST_PS,attack,attack_P1,attack_P2,attack_P3
continuous
P1_B2004,P1_B2016,P1_B3004,P1_B3005,P1_B4002,P1_B4005,P1_B400B,P1_B4022,P1_FCV02Z,P1_FCV03D,P1_FCV03Z,P1_FT01,P1_FT01Z,P1_FT02,P1_FT02Z,P1_FT03,P1_FT03Z,P1_LCV01D,P1_LCV01Z,P1_LIT01,P1_PCV01D,P1_PCV01Z,P1_PCV02Z,P1_PIT01,P1_PIT02,P1_TIT01,P1_TIT02,P2_24Vdc,P2_CO_rpm,P2_HILout,P2_MSD,P2_SIT01,P2_SIT02,P2_VT01,P2_VXT02,P2_VXT03,P2_VYT02,P2_VYT03,P3_FIT01,P3_LCP01D,P3_LCV01D,P3_LIT01,P3_PIT01,P4_HT_FD,P4_HT_LD,P4_HT_PO,P4_LD,P4_ST_FD,P4_ST_GOV,P4_ST_LD,P4_ST_PO,P4_ST_PT01,P4_ST_TT01
BIN  example/results/model.pt  Normal file
Binary file not shown.
6  example/results/summary.txt  Normal file
@@ -0,0 +1,6 @@
rows_sampled: 5000
columns_total: 84
continuous: 53
discrete: 30
unknown: 0
data_path: /home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz
88  example/sample.py  Executable file
@@ -0,0 +1,88 @@
#!/usr/bin/env python3
"""Sampling stub for hybrid diffusion (continuous + discrete)."""

import json
import math
import os

import torch
import torch.nn.functional as F

from data_utils import load_split
from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule

SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
VOCAB_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"
MODEL_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/model.pt"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TIMESTEPS = 200
SEQ_LEN = 64
BATCH_SIZE = 2


def load_vocab():
    with open(VOCAB_PATH, "r", encoding="ascii") as f:
        return json.load(f)["vocab"]


def main():
    split = load_split(SPLIT_PATH)
    cont_cols = split["continuous"]
    disc_cols = split["discrete"]

    vocab = load_vocab()
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    print("device", DEVICE)
    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
    if os.path.exists(MODEL_PATH):
        model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True))
    model.eval()

    betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    x_cont = torch.randn(BATCH_SIZE, SEQ_LEN, len(cont_cols), device=DEVICE)
    x_disc = torch.full((BATCH_SIZE, SEQ_LEN, len(disc_cols)), 0, device=DEVICE, dtype=torch.long)
    mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)

    # Initialize discrete with mask tokens
    for i in range(len(disc_cols)):
        x_disc[:, :, i] = mask_tokens[i]

    for t in reversed(range(TIMESTEPS)):
        t_batch = torch.full((BATCH_SIZE,), t, device=DEVICE, dtype=torch.long)
        eps_pred, logits = model(x_cont, x_disc, t_batch)

        # Continuous reverse step (DDPM): x_{t-1} mean
        a_t = alphas[t]
        a_bar_t = alphas_cumprod[t]
        coef1 = 1.0 / torch.sqrt(a_t)
        coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
        mean = coef1 * (x_cont - coef2 * eps_pred)
        if t > 0:
            noise = torch.randn_like(x_cont)
            x_cont = mean + torch.sqrt(betas[t]) * noise
        else:
            x_cont = mean

        # Discrete: fill masked positions by sampling logits
        for i, logit in enumerate(logits):
            if t == 0:
                probs = F.softmax(logit, dim=-1)
                x_disc[:, :, i] = torch.argmax(probs, dim=-1)
            else:
                mask = x_disc[:, :, i] == mask_tokens[i]
                if mask.any():
                    probs = F.softmax(logit, dim=-1)
                    sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(BATCH_SIZE, SEQ_LEN)
                    x_disc[:, :, i][mask] = sampled[mask]

    print("sampled_cont_shape", tuple(x_cont.shape))
    print("sampled_disc_shape", tuple(x_disc.shape))


if __name__ == "__main__":
    main()
114  example/train.py  Executable file
@@ -0,0 +1,114 @@
#!/usr/bin/env python3
"""Train hybrid diffusion on HAI 21.03 (minimal runnable example)."""

import json
import os

import torch
import torch.nn.functional as F

from data_utils import load_split, windowed_batches
from hybrid_diffusion import (
    HybridDiffusionModel,
    cosine_beta_schedule,
    q_sample_continuous,
    q_sample_discrete,
)

DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
STATS_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/cont_stats.json"
VOCAB_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"
OUT_DIR = "/home/anay/Dev/diffusion/mask-ddpm/example/results"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TIMESTEPS = 1000
BATCH_SIZE = 8
SEQ_LEN = 64
EPOCHS = 1
MAX_BATCHES = 50
LAMBDA = 0.5
LR = 1e-3


def load_stats():
    with open(STATS_PATH, "r", encoding="ascii") as f:
        return json.load(f)


def load_vocab():
    with open(VOCAB_PATH, "r", encoding="ascii") as f:
        return json.load(f)["vocab"]


def main():
    split = load_split(SPLIT_PATH)
    cont_cols = split["continuous"]
    disc_cols = split["discrete"]

    stats = load_stats()
    mean = stats["mean"]
    std = stats["std"]
    vocab = load_vocab()

    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    print("device", DEVICE)
    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=LR)

    betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    os.makedirs(OUT_DIR, exist_ok=True)

    for epoch in range(EPOCHS):
        for step, (x_cont, x_disc) in enumerate(
            windowed_batches(
                DATA_PATH,
                cont_cols,
                disc_cols,
                vocab,
                mean,
                std,
                batch_size=BATCH_SIZE,
                seq_len=SEQ_LEN,
                max_batches=MAX_BATCHES,
            )
        ):
            x_cont = x_cont.to(DEVICE)
            x_disc = x_disc.to(DEVICE)

            bsz = x_cont.size(0)
            t = torch.randint(0, TIMESTEPS, (bsz,), device=DEVICE)

            x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)

            mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
            x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, TIMESTEPS)

            eps_pred, logits = model(x_cont_t, x_disc_t, t)

            loss_cont = F.mse_loss(eps_pred, noise)

            loss_disc = 0.0
            for i, logit in enumerate(logits):
                if mask[:, :, i].any():
                    loss_disc = loss_disc + F.cross_entropy(
                        logit[mask[:, :, i]], x_disc[:, :, i][mask[:, :, i]]
                    )

            loss = LAMBDA * loss_cont + (1 - LAMBDA) * loss_disc
            opt.zero_grad()
            loss.backward()
            opt.step()

            if step % 10 == 0:
                print("epoch", epoch, "step", step, "loss", float(loss))

    torch.save(model.state_dict(), os.path.join(OUT_DIR, "model.pt"))


if __name__ == "__main__":
    main()
113  example/train_stub.py  Executable file
@@ -0,0 +1,113 @@
#!/usr/bin/env python3
"""Training stub for hybrid diffusion on HAI 21.03.

This is a scaffold that shows data loading, forward noising, and loss setup.
"""

import csv
import gzip
import json
import math
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F

from hybrid_diffusion import (
    HybridDiffusionModel,
    cosine_beta_schedule,
    q_sample_continuous,
    q_sample_discrete,
)

DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
DEVICE = "cpu"
TIMESTEPS = 1000


def load_split(path: str) -> Dict[str, List[str]]:
    with open(path, "r", encoding="ascii") as f:
        return json.load(f)


def iter_rows(path: str):
    with gzip.open(path, "rt", newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            yield row


def build_vocab_sizes(path: str, disc_cols: List[str], max_rows: int = 5000) -> List[int]:
    values = {c: set() for c in disc_cols}
    for i, row in enumerate(iter_rows(path)):
        for c in disc_cols:
            v = row[c]
            values[c].add(v)
        if i + 1 >= max_rows:
            break
    sizes = [len(values[c]) for c in disc_cols]
    return sizes


def load_batch(path: str, cont_cols: List[str], disc_cols: List[str], batch_size: int = 8, seq_len: int = 64):
    cont = []
    disc = []
    current = []
    for row in iter_rows(path):
        cont_row = [float(row[c]) for c in cont_cols]
        disc_row = [int(float(row[c])) for c in disc_cols]
        current.append((cont_row, disc_row))
        if len(current) == seq_len:
            cont.append([r[0] for r in current])
            disc.append([r[1] for r in current])
            current = []
            if len(cont) == batch_size:
                break
    x_cont = torch.tensor(cont, dtype=torch.float32)
    x_disc = torch.tensor(disc, dtype=torch.long)
    return x_cont, x_disc


def main():
    split = load_split(SPLIT_PATH)
    cont_cols = split["continuous"]
    disc_cols = split["discrete"]

    vocab_sizes = build_vocab_sizes(DATA_PATH, disc_cols)
    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)

    betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    x_cont, x_disc = load_batch(DATA_PATH, cont_cols, disc_cols)
    x_cont = x_cont.to(DEVICE)
    x_disc = x_disc.to(DEVICE)

    bsz = x_cont.size(0)
    t = torch.randint(0, TIMESTEPS, (bsz,), device=DEVICE)

    x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)

    mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
    x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, TIMESTEPS)

    eps_pred, logits = model(x_cont_t, x_disc_t, t)

    loss_cont = F.mse_loss(eps_pred, noise)

    loss_disc = 0.0
    for i, logit in enumerate(logits):
        # Cross-entropy only on masked positions of feature i (skip if none were masked).
        target = x_disc[:, :, i]
        if mask[:, :, i].any():
            loss_disc = loss_disc + F.cross_entropy(logit[mask[:, :, i]], target[mask[:, :, i]])

    lam = 0.5
    loss = lam * loss_cont + (1 - lam) * loss_disc
    print("loss_cont", float(loss_cont), "loss_disc", float(loss_disc), "loss", float(loss))


if __name__ == "__main__":
    main()