Clean artifacts and update example pipeline

.gitignore (new file, vendored, 31 lines)
@@ -0,0 +1,31 @@
+__pycache__/
+*.pyc
+*.pyo
+*.pyd
+.DS_Store
+.idea/
+.vscode/
+
+# Python envs
+.venv/
+venv/
+env/
+
+# Datasets and large artifacts
+dataset/
+*.pcap
+*.pcapng
+*.gz
+*.zip
+*.tar
+*.tar.gz
+
+# Model artifacts and results
+mask-ddpm/example/results/
+*.pt
+*.pth
+*.ckpt
+*.png
+
+# Logs
+*.log
@@ -12,33 +12,54 @@ CSV (train1) and produces a continuous/discrete split using a simple heuristic.
 - train_stub.py: end-to-end scaffold for loss computation.
 - train.py: minimal training loop with checkpoints.
 - sample.py: minimal sampling loop.
+- export_samples.py: sample + export to CSV with original column names.
+- evaluate_generated.py: basic eval of generated CSV vs training stats.
+- config.json: training defaults for train.py.
 - model_design.md: step-by-step design notes.
 - results/feature_split.txt: comma-separated feature lists.
 - results/summary.txt: basic stats (rows sampled, column counts).
 
 ## Run
 ```
-python /home/anay/Dev/diffusion/mask-ddpm/example/analyze_hai21_03.py
+python example/analyze_hai21_03.py
 ```
 
 Prepare vocab + stats (writes to `example/results`):
 ```
-python /home/anay/Dev/diffusion/mask-ddpm/example/prepare_data.py
+python example/prepare_data.py
 ```
 
 Train a small run:
 ```
-python /home/anay/Dev/diffusion/mask-ddpm/example/train.py
+python example/train.py --config example/config.json
 ```
 
 Sample from the trained model:
 ```
-python /home/anay/Dev/diffusion/mask-ddpm/example/sample.py
+python example/sample.py
+```
+
+Sample and export CSV:
+```
+python example/export_samples.py --include-time --device cpu
+```
+
+Evaluate generated CSV (writes eval.json):
+```
+python example/evaluate_generated.py
+```
+
+One-click pipeline (prepare -> train -> export -> eval -> plot):
+```
+python example/run_pipeline.py --device auto
 ```
 
 ## Notes
 - Heuristic: integer-like values with low cardinality (<=10) are treated as
   discrete. All other numeric columns are continuous.
+- Set `device` in `example/config.json` to `auto` or `cuda` when moving to a GPU machine.
+- Attack label columns (`attack*`) are excluded from training and generation.
+- `time` column is always excluded from training and generation (optional for export only).
 - The script only samples the first 5000 rows to stay fast.
 - `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
 - `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU.
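For readers skimming the diff, the continuous/discrete heuristic referred to in the Notes can be sketched roughly as follows. This is a minimal illustration, not the code in `analyze_hai21_03.py`; the function name, input format, and variable names are assumptions.

```python
def classify_column(raw_values, max_card=10):
    """Sketch of the split heuristic: return 'discrete' or 'continuous' for one column."""
    try:
        nums = [float(v) for v in raw_values]  # raw CSV strings for the column
    except ValueError:
        return "discrete"  # non-numeric columns count as discrete
    # Integer-like values with low cardinality (<= 10 distinct values) are discrete.
    if all(n.is_integer() for n in nums) and len(set(nums)) <= max_card:
        return "discrete"
    return "continuous"
```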

Binary file not shown.
@@ -8,9 +8,12 @@ Everything else numeric -> continuous. Non-numeric -> discrete.
 import csv
 import gzip
 import os
+from pathlib import Path
 
-DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
-OUT_DIR = "/home/anay/Dev/diffusion/mask-ddpm/example/results"
+BASE_DIR = Path(__file__).resolve().parent
+REPO_DIR = BASE_DIR.parent.parent
+DATA_PATH = str(REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz")
+OUT_DIR = str(BASE_DIR / "results")
 MAX_ROWS = 5000
 
 

example/config.json (new file, 18 lines)
@@ -0,0 +1,18 @@
+{
+  "data_path": "../../dataset/hai/hai-21.03/train1.csv.gz",
+  "split_path": "./feature_split.json",
+  "stats_path": "./results/cont_stats.json",
+  "vocab_path": "./results/disc_vocab.json",
+  "out_dir": "./results",
+  "device": "auto",
+  "timesteps": 400,
+  "batch_size": 64,
+  "seq_len": 128,
+  "epochs": 5,
+  "max_batches": 2000,
+  "lambda": 0.5,
+  "lr": 0.0005,
+  "seed": 1337,
+  "log_every": 10,
+  "ckpt_every": 50
+}
@@ -66,6 +66,8 @@ def build_vocab(
     vocab = {}
     for c in disc_cols:
         tokens = sorted(values[c])
+        if "<UNK>" not in tokens:
+            tokens.append("<UNK>")
         vocab[c] = {tok: idx for idx, tok in enumerate(tokens)}
     return vocab
 
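A toy illustration of what the `<UNK>` entry added above buys in `windowed_batches` below; the column name and values here are made up for the example.

```python
# Assumed toy vocab as built by build_vocab (now guaranteed to contain "<UNK>"):
vocab = {"P2_OnOff": {"0": 0, "1": 1, "<UNK>": 2}}
row = {"P2_OnOff": "2"}  # a token never seen while the vocab was built

# The old lookup vocab[c][row[c]] would raise KeyError here; the .get fallback
# introduced in the next hunk maps the unseen token to <UNK> instead.
idx = vocab["P2_OnOff"].get(row["P2_OnOff"], vocab["P2_OnOff"]["<UNK>"])
assert idx == 2
```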
@@ -105,7 +107,7 @@ def windowed_batches(
     batches_yielded = 0
     for row in iter_rows(path):
         cont_row = [float(row[c]) for c in cont_cols]
-        disc_row = [vocab[c][row[c]] for c in disc_cols]
+        disc_row = [vocab[c].get(row[c], vocab[c]["<UNK>"]) for c in disc_cols]
         seq_cont.append(cont_row)
         seq_disc.append(disc_row)
         if len(seq_cont) == seq_len:

example/evaluate_generated.py (new file, 116 lines)
@@ -0,0 +1,116 @@
+#!/usr/bin/env python3
+"""Evaluate generated samples against simple stats and vocab validity."""
+
+import argparse
+import csv
+import gzip
+import json
+from pathlib import Path
+from typing import Dict, Tuple
+
+
+def load_json(path: str) -> Dict:
+    with open(path, "r", encoding="ascii") as f:
+        return json.load(f)
+
+
+def open_csv(path: str):
+    if path.endswith(".gz"):
+        return gzip.open(path, "rt", newline="")
+    return open(path, "r", newline="")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Evaluate generated CSV samples.")
+    base_dir = Path(__file__).resolve().parent
+    parser.add_argument("--generated", default=str(base_dir / "results" / "generated.csv"))
+    parser.add_argument("--split", default=str(base_dir / "feature_split.json"))
+    parser.add_argument("--stats", default=str(base_dir / "results" / "cont_stats.json"))
+    parser.add_argument("--vocab", default=str(base_dir / "results" / "disc_vocab.json"))
+    parser.add_argument("--out", default=str(base_dir / "results" / "eval.json"))
+    return parser.parse_args()
+
+
+def init_stats(cols):
+    return {c: {"count": 0, "mean": 0.0, "m2": 0.0} for c in cols}
+
+
+def update_stats(stats, col, value):
+    st = stats[col]
+    st["count"] += 1
+    delta = value - st["mean"]
+    st["mean"] += delta / st["count"]
+    delta2 = value - st["mean"]
+    st["m2"] += delta * delta2
+
+
+def finalize_stats(stats):
+    out = {}
+    for c, st in stats.items():
+        if st["count"] > 1:
+            var = st["m2"] / (st["count"] - 1)
+        else:
+            var = 0.0
+        out[c] = {"mean": st["mean"], "std": var ** 0.5}
+    return out
+
+
+def main():
+    args = parse_args()
+    split = load_json(args.split)
+    time_col = split.get("time_column", "time")
+    cont_cols = [c for c in split["continuous"] if c != time_col]
+    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
+
+    stats_ref = load_json(args.stats)["mean"]
+    std_ref = load_json(args.stats)["std"]
+    vocab = load_json(args.vocab)["vocab"]
+    vocab_sets = {c: set(vocab[c].keys()) for c in disc_cols}
+
+    cont_stats = init_stats(cont_cols)
+    disc_invalid = {c: 0 for c in disc_cols}
+    rows = 0
+
+    with open_csv(args.generated) as f:
+        reader = csv.DictReader(f)
+        for row in reader:
+            rows += 1
+            if time_col in row:
+                row.pop(time_col, None)
+            for c in cont_cols:
+                try:
+                    v = float(row[c])
+                except Exception:
+                    v = 0.0
+                update_stats(cont_stats, c, v)
+            for c in disc_cols:
+                if row[c] not in vocab_sets[c]:
+                    disc_invalid[c] += 1
+
+    cont_summary = finalize_stats(cont_stats)
+    cont_err = {}
+    for c in cont_cols:
+        ref_mean = float(stats_ref[c])
+        ref_std = float(std_ref[c]) if float(std_ref[c]) != 0 else 1.0
+        gen_mean = cont_summary[c]["mean"]
+        gen_std = cont_summary[c]["std"]
+        cont_err[c] = {
+            "mean_abs_err": abs(gen_mean - ref_mean),
+            "std_abs_err": abs(gen_std - ref_std),
+        }
+
+    report = {
+        "rows": rows,
+        "continuous_summary": cont_summary,
+        "continuous_error": cont_err,
+        "discrete_invalid_counts": disc_invalid,
+    }
+
+    with open(args.out, "w", encoding="ascii") as f:
+        json.dump(report, f, indent=2)
+
+    print("eval_report", args.out)
+
+
+if __name__ == "__main__":
+    main()
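For reference, `update_stats` and `finalize_stats` above implement the usual online (Welford-style) mean/variance recursion; for the n-th value x_n of a column:

```latex
\delta = x_n - \mu_{n-1}, \qquad
\mu_n = \mu_{n-1} + \frac{\delta}{n}, \qquad
M_{2,n} = M_{2,n-1} + \delta\,(x_n - \mu_n), \qquad
s_n^2 = \frac{M_{2,n}}{n-1}
```

`finalize_stats` reports the sample standard deviation as the square root of that variance, falling back to 0 when fewer than two rows were seen.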

example/export_samples.py (new file, 183 lines)
@@ -0,0 +1,183 @@
+#!/usr/bin/env python3
+"""Sample from a trained hybrid diffusion model and export to CSV."""
+
+import argparse
+import csv
+import gzip
+import json
+import os
+from pathlib import Path
+from typing import Dict, List
+
+import torch
+import torch.nn.functional as F
+
+from data_utils import load_split
+from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
+
+
+def load_vocab(path: str) -> Dict[str, Dict[str, int]]:
+    with open(path, "r", encoding="ascii") as f:
+        return json.load(f)["vocab"]
+
+
+def load_stats(path: str):
+    with open(path, "r", encoding="ascii") as f:
+        return json.load(f)
+
+
+def read_header(path: str) -> List[str]:
+    if path.endswith(".gz"):
+        opener = gzip.open
+        mode = "rt"
+    else:
+        opener = open
+        mode = "r"
+    with opener(path, mode, newline="") as f:
+        reader = csv.reader(f)
+        return next(reader)
+
+
+def build_inverse_vocab(vocab: Dict[str, Dict[str, int]]) -> Dict[str, List[str]]:
+    inv = {}
+    for col, mapping in vocab.items():
+        inverse = [""] * len(mapping)
+        for tok, idx in mapping.items():
+            inverse[idx] = tok
+        inv[col] = inverse
+    return inv
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Sample and export HAI feature sequences.")
+    base_dir = Path(__file__).resolve().parent
+    repo_dir = base_dir.parent.parent
+    parser.add_argument("--data-path", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"))
+    parser.add_argument("--split-path", default=str(base_dir / "feature_split.json"))
+    parser.add_argument("--stats-path", default=str(base_dir / "results" / "cont_stats.json"))
+    parser.add_argument("--vocab-path", default=str(base_dir / "results" / "disc_vocab.json"))
+    parser.add_argument("--model-path", default=str(base_dir / "results" / "model.pt"))
+    parser.add_argument("--out", default=str(base_dir / "results" / "generated.csv"))
+    parser.add_argument("--timesteps", type=int, default=200)
+    parser.add_argument("--seq-len", type=int, default=64)
+    parser.add_argument("--batch-size", type=int, default=2)
+    parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
+    parser.add_argument("--include-time", action="store_true", help="Include time column as a simple index")
+    return parser.parse_args()
+
+
+def resolve_device(mode: str) -> str:
+    mode = mode.lower()
+    if mode == "cpu":
+        return "cpu"
+    if mode == "cuda":
+        if not torch.cuda.is_available():
+            raise SystemExit("device set to cuda but CUDA is not available")
+        return "cuda"
+    if torch.cuda.is_available():
+        return "cuda"
+    return "cpu"
+
+
+def main():
+    args = parse_args()
+    if not os.path.exists(args.model_path):
+        raise SystemExit("missing model file: %s" % args.model_path)
+
+    split = load_split(args.split_path)
+    time_col = split.get("time_column", "time")
+    cont_cols = [c for c in split["continuous"] if c != time_col]
+    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
+
+    stats = load_stats(args.stats_path)
+    mean = stats["mean"]
+    std = stats["std"]
+
+    vocab = load_vocab(args.vocab_path)
+    inv_vocab = build_inverse_vocab(vocab)
+    vocab_sizes = [len(vocab[c]) for c in disc_cols]
+
+    device = resolve_device(args.device)
+    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device)
+    model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
+    model.eval()
+
+    betas = cosine_beta_schedule(args.timesteps).to(device)
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+
+    x_cont = torch.randn(args.batch_size, args.seq_len, len(cont_cols), device=device)
+    x_disc = torch.full(
+        (args.batch_size, args.seq_len, len(disc_cols)),
+        0,
+        device=device,
+        dtype=torch.long,
+    )
+    mask_tokens = torch.tensor(vocab_sizes, device=device)
+    for i in range(len(disc_cols)):
+        x_disc[:, :, i] = mask_tokens[i]
+
+    for t in reversed(range(args.timesteps)):
+        t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
+        eps_pred, logits = model(x_cont, x_disc, t_batch)
+
+        a_t = alphas[t]
+        a_bar_t = alphas_cumprod[t]
+        coef1 = 1.0 / torch.sqrt(a_t)
+        coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
+        mean_x = coef1 * (x_cont - coef2 * eps_pred)
+        if t > 0:
+            noise = torch.randn_like(x_cont)
+            x_cont = mean_x + torch.sqrt(betas[t]) * noise
+        else:
+            x_cont = mean_x
+
+        for i, logit in enumerate(logits):
+            if t == 0:
+                probs = F.softmax(logit, dim=-1)
+                x_disc[:, :, i] = torch.argmax(probs, dim=-1)
+            else:
+                mask = x_disc[:, :, i] == mask_tokens[i]
+                if mask.any():
+                    probs = F.softmax(logit, dim=-1)
+                    sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(
+                        args.batch_size, args.seq_len
+                    )
+                    x_disc[:, :, i][mask] = sampled[mask]
+
+    x_cont = x_cont.cpu()
+    x_disc = x_disc.cpu()
+
+    mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
+    std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
+    x_cont = x_cont * std_vec + mean_vec
+
+    header = read_header(args.data_path)
+    out_cols = [c for c in header if c != time_col or args.include_time]
+
+    os.makedirs(os.path.dirname(args.out), exist_ok=True)
+    with open(args.out, "w", newline="", encoding="ascii") as f:
+        writer = csv.DictWriter(f, fieldnames=out_cols)
+        writer.writeheader()
+
+        row_index = 0
+        for b in range(args.batch_size):
+            for t in range(args.seq_len):
+                row = {}
+                if args.include_time and time_col in header:
+                    row[time_col] = str(row_index)
+                for i, c in enumerate(cont_cols):
+                    row[c] = ("%.6f" % float(x_cont[b, t, i]))
+                for i, c in enumerate(disc_cols):
+                    tok_idx = int(x_disc[b, t, i])
+                    tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "0"
+                    row[c] = tok
+                writer.writerow(row)
+                row_index += 1
+
+    print("exported_csv", args.out)
+    print("rows", args.batch_size * args.seq_len)
+
+
+if __name__ == "__main__":
+    main()
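The continuous branch of the sampling loop in `export_samples.py` above is the standard DDPM ancestral update, with alpha_t = 1 - beta_t and bar-alpha_t the running product of the alphas:

```latex
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right) + \sqrt{\beta_t}\,z,
\qquad z \sim \mathcal{N}(0, I)\ \text{for } t > 0, \qquad z = 0\ \text{at } t = 0
```

The discrete branch keeps tokens that are still at the mask index, re-samples them from the predicted logits at each step, and commits to the argmax at t = 0.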

example/plot_loss.py (new file, 63 lines)
@@ -0,0 +1,63 @@
+#!/usr/bin/env python3
+"""Plot training loss curves from train_log.csv."""
+
+import argparse
+import csv
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Plot loss curves from train_log.csv")
+    base_dir = Path(__file__).resolve().parent
+    parser.add_argument(
+        "--log",
+        default=str(base_dir / "results" / "train_log.csv"),
+        help="Path to train_log.csv",
+    )
+    parser.add_argument(
+        "--out",
+        default=str(base_dir / "results" / "train_loss.png"),
+        help="Output PNG path",
+    )
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+    log_path = Path(args.log)
+    if not log_path.exists():
+        raise SystemExit("missing log file: %s" % log_path)
+
+    steps = []
+    loss = []
+    loss_cont = []
+    loss_disc = []
+
+    with log_path.open("r", encoding="ascii", newline="") as f:
+        reader = csv.DictReader(f)
+        for row in reader:
+            steps.append(int(row["step"]))
+            loss.append(float(row["loss"]))
+            loss_cont.append(float(row["loss_cont"]))
+            loss_disc.append(float(row["loss_disc"]))
+
+    if not steps:
+        raise SystemExit("no rows in log file: %s" % log_path)
+
+    plt.figure(figsize=(8, 5))
+    plt.plot(steps, loss, label="total")
+    plt.plot(steps, loss_cont, label="continuous")
+    plt.plot(steps, loss_disc, label="discrete")
+    plt.xlabel("step")
+    plt.ylabel("loss")
+    plt.title("Training Loss")
+    plt.legend()
+    plt.tight_layout()
+    plt.savefig(args.out, dpi=150)
+    print("saved", args.out)
+
+
+if __name__ == "__main__":
+    main()
@@ -2,20 +2,24 @@
 """Prepare vocab and normalization stats for HAI 21.03."""
 
 import json
+from pathlib import Path
 from typing import Optional
 
 from data_utils import compute_cont_stats, build_vocab, load_split
 
-DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
-SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
-OUT_STATS = "/home/anay/Dev/diffusion/mask-ddpm/example/results/cont_stats.json"
-OUT_VOCAB = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"
+BASE_DIR = Path(__file__).resolve().parent
+REPO_DIR = BASE_DIR.parent.parent
+DATA_PATH = str(REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz")
+SPLIT_PATH = str(BASE_DIR / "feature_split.json")
+OUT_STATS = str(BASE_DIR / "results" / "cont_stats.json")
+OUT_VOCAB = str(BASE_DIR / "results" / "disc_vocab.json")
 
 
 def main(max_rows: Optional[int] = None):
     split = load_split(SPLIT_PATH)
-    cont_cols = split["continuous"]
-    disc_cols = split["discrete"]
+    time_col = split.get("time_column", "time")
+    cont_cols = [c for c in split["continuous"] if c != time_col]
+    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
+
     mean, std = compute_cont_stats(DATA_PATH, cont_cols, max_rows=max_rows)
     vocab = build_vocab(DATA_PATH, disc_cols, max_rows=max_rows)
@@ -1,113 +0,0 @@
-{
-  "mean": {
-    "P1_B2004": 0.08649086820000026,
-    "P1_B2016": 1.376161456000001,
-    "P1_B3004": 396.1861596906018,
-    "P1_B3005": 1037.372384413793,
-    "P1_B4002": 32.564872940799994,
-    "P1_B4005": 65.98190757240047,
-    "P1_B400B": 1925.0391570245934,
-    "P1_B4022": 36.28908066800001,
-    "P1_FCV02Z": 21.744261118400036,
-    "P1_FCV03D": 57.36123274140044,
-    "P1_FCV03Z": 58.05084519640002,
-    "P1_FT01": 184.18615112319728,
-    "P1_FT01Z": 851.8781750705965,
-    "P1_FT02": 1255.8572173544069,
-    "P1_FT02Z": 1925.0210755194114,
-    "P1_FT03": 269.37285885780574,
-    "P1_FT03Z": 1037.366172230601,
-    "P1_LCV01D": 11.228849048599963,
-    "P1_LCV01Z": 10.991610181600016,
-    "P1_LIT01": 396.8845311109994,
-    "P1_PCV01D": 53.80101618419986,
-    "P1_PCV01Z": 54.646640287199595,
-    "P1_PCV02Z": 12.017773542800072,
-    "P1_PIT01": 1.3692859488000075,
-    "P1_PIT02": 0.44459071260000227,
-    "P1_TIT01": 35.64255813999988,
-    "P1_TIT02": 36.44807823060023,
-    "P2_24Vdc": 28.0280019013999,
-    "P2_CO_rpm": 54105.64434999997,
-    "P2_HILout": 712.0588667425922,
-    "P2_MSD": 763.19324,
-    "P2_SIT01": 778.7769850000013,
-    "P2_SIT02": 778.7778935471981,
-    "P2_VT01": 11.914949448200044,
-    "P2_VXT02": -3.5267871940000175,
-    "P2_VXT03": -1.5520904921999914,
-    "P2_VYT02": 3.796112737600002,
-    "P2_VYT03": 6.121691697000018,
-    "P3_FIT01": 1168.2528800000014,
-    "P3_LCP01D": 4675.465239999989,
-    "P3_LCV01D": 7445.208720000017,
-    "P3_LIT01": 13728.982314999852,
-    "P3_PIT01": 668.9722350000003,
-    "P4_HT_FD": -0.00010012580000000082,
-    "P4_HT_LD": 35.41945000099953,
-    "P4_HT_PO": 35.4085699912002,
-    "P4_LD": 365.3833745803986,
-    "P4_ST_FD": -6.5205999999999635e-06,
-    "P4_ST_GOV": 17801.81294499996,
-    "P4_ST_LD": 329.83259218199964,
-    "P4_ST_PO": 330.1079461497967,
-    "P4_ST_PT01": 10047.679605000127,
-    "P4_ST_TT01": 27606.860070000155
-  },
-  "std": {
-    "P1_B2004": 0.024492489898690458,
-    "P1_B2016": 0.12949272564759745,
-    "P1_B3004": 10.16264800653289,
-    "P1_B3005": 70.85697659109,
-    "P1_B4002": 0.7578213113008356,
-    "P1_B4005": 41.80065314991797,
-    "P1_B400B": 1176.6445547448632,
-    "P1_B4022": 0.8221115066487089,
-    "P1_FCV02Z": 39.11843197764176,
-    "P1_FCV03D": 7.889507447726624,
-    "P1_FCV03Z": 8.046068905945717,
-    "P1_FT01": 30.80117031882856,
-    "P1_FT01Z": 91.2786865433318,
-    "P1_FT02": 879.7163277334494,
-    "P1_FT02Z": 1176.6699531305114,
-    "P1_FT03": 38.18015841964941,
-    "P1_FT03Z": 70.73100774436428,
-    "P1_LCV01D": 3.3355655415557597,
-    "P1_LCV01Z": 3.386332233773545,
-    "P1_LIT01": 10.57871476010412,
-    "P1_PCV01D": 19.61567943613885,
-    "P1_PCV01Z": 19.778754467302086,
-    "P1_PCV02Z": 0.004804797893159998,
-    "P1_PIT01": 0.0776614954053113,
-    "P1_PIT02": 0.44823231815652304,
-    "P1_TIT01": 0.5986678527528815,
-    "P1_TIT02": 1.1892341204521049,
-    "P2_24Vdc": 0.00320884250409781,
-    "P2_CO_rpm": 20.57547782150726,
-    "P2_HILout": 8.17885337990861,
-    "P2_MSD": 1.0,
-    "P2_SIT01": 3.894535775667256,
-    "P2_SIT02": 3.882477078857941,
-    "P2_VT01": 0.06812990916670243,
-    "P2_VXT02": 0.43104157117568803,
-    "P2_VXT03": 0.26894251958139775,
-    "P2_VYT02": 0.46109078832075856,
-    "P2_VYT03": 0.3059642938507547,
-    "P3_FIT01": 1787.2987693141868,
-    "P3_LCP01D": 5145.4094261812725,
-    "P3_LCV01D": 6785.602781765096,
-    "P3_LIT01": 4060.915441872745,
-    "P3_PIT01": 1168.1071264424027,
-    "P4_HT_FD": 0.002032582380617592,
-    "P4_HT_LD": 33.212361169253235,
-    "P4_HT_PO": 31.187825914515162,
-    "P4_LD": 59.736616589045646,
-    "P4_ST_FD": 0.0016428787127432496,
-    "P4_ST_GOV": 1740.5997458128215,
-    "P4_ST_LD": 35.86633288900077,
-    "P4_ST_PO": 32.375012735256696,
-    "P4_ST_PT01": 22.459962818146252,
-    "P4_ST_TT01": 24.745939350221477
-  },
-  "max_rows": 50000
-}
File diff suppressed because it is too large
@@ -1,4 +0,0 @@
-discrete
-P1_FCV01D,P1_FCV01Z,P1_FCV02D,P1_PCV02D,P1_PP01AD,P1_PP01AR,P1_PP01BD,P1_PP01BR,P1_PP02D,P1_PP02R,P1_STSP,P2_ASD,P2_AutoGO,P2_Emerg,P2_ManualGO,P2_OnOff,P2_RTR,P2_TripEx,P2_VTR01,P2_VTR02,P2_VTR03,P2_VTR04,P3_LH,P3_LL,P4_HT_PS,P4_ST_PS,attack,attack_P1,attack_P2,attack_P3
-continuous
-P1_B2004,P1_B2016,P1_B3004,P1_B3005,P1_B4002,P1_B4005,P1_B400B,P1_B4022,P1_FCV02Z,P1_FCV03D,P1_FCV03Z,P1_FT01,P1_FT01Z,P1_FT02,P1_FT02Z,P1_FT03,P1_FT03Z,P1_LCV01D,P1_LCV01Z,P1_LIT01,P1_PCV01D,P1_PCV01Z,P1_PCV02Z,P1_PIT01,P1_PIT02,P1_TIT01,P1_TIT02,P2_24Vdc,P2_CO_rpm,P2_HILout,P2_MSD,P2_SIT01,P2_SIT02,P2_VT01,P2_VXT02,P2_VXT03,P2_VYT02,P2_VYT03,P3_FIT01,P3_LCP01D,P3_LCV01D,P3_LIT01,P3_PIT01,P4_HT_FD,P4_HT_LD,P4_HT_PO,P4_LD,P4_ST_FD,P4_ST_GOV,P4_ST_LD,P4_ST_PO,P4_ST_PT01,P4_ST_TT01
Binary file not shown.
@@ -1,6 +0,0 @@
-rows_sampled: 5000
-columns_total: 84
-continuous: 53
-discrete: 30
-unknown: 0
-data_path: /home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz

example/run_pipeline.py (new file, 50 lines)
@@ -0,0 +1,50 @@
+#!/usr/bin/env python3
+"""One-click pipeline: prepare -> train -> export -> evaluate -> plot loss."""
+
+import argparse
+import subprocess
+import sys
+from pathlib import Path
+
+
+def run(cmd):
+    print("running:", " ".join(cmd))
+    subprocess.run(cmd, check=True)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Run full HAI pipeline.")
+    base_dir = Path(__file__).resolve().parent
+    parser.add_argument(
+        "--config",
+        default=str(base_dir / "config.json"),
+        help="Path to training config JSON",
+    )
+    parser.add_argument(
+        "--device",
+        default="auto",
+        help="cpu, cuda, or auto (used for export_samples.py)",
+    )
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+    base_dir = Path(__file__).resolve().parent
+    run([sys.executable, str(base_dir / "prepare_data.py")])
+    run([sys.executable, str(base_dir / "train.py"), "--config", args.config])
+    run(
+        [
+            sys.executable,
+            str(base_dir / "export_samples.py"),
+            "--include-time",
+            "--device",
+            args.device,
+        ]
+    )
+    run([sys.executable, str(base_dir / "evaluate_generated.py")])
+    run([sys.executable, str(base_dir / "plot_loss.py")])
+
+
+if __name__ == "__main__":
+    main()
@@ -4,6 +4,7 @@
 import json
 import math
 import os
+from pathlib import Path
 
 import torch
 import torch.nn.functional as F
@@ -11,11 +12,25 @@ import torch.nn.functional as F
 from data_utils import load_split
 from hybrid_diffusion import HybridDiffusionModel, cosine_beta_schedule
 
-SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
-VOCAB_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"
-MODEL_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/model.pt"
-
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+BASE_DIR = Path(__file__).resolve().parent
+SPLIT_PATH = str(BASE_DIR / "feature_split.json")
+VOCAB_PATH = str(BASE_DIR / "results" / "disc_vocab.json")
+MODEL_PATH = str(BASE_DIR / "results" / "model.pt")
+
+
+def resolve_device(mode: str) -> str:
+    mode = mode.lower()
+    if mode == "cpu":
+        return "cpu"
+    if mode == "cuda":
+        if not torch.cuda.is_available():
+            raise SystemExit("device set to cuda but CUDA is not available")
+        return "cuda"
+    if torch.cuda.is_available():
+        return "cuda"
+    return "cpu"
+
+
+DEVICE = resolve_device("auto")
 TIMESTEPS = 200
 SEQ_LEN = 64
 BATCH_SIZE = 2
@@ -28,8 +43,9 @@ def load_vocab():
 
 def main():
     split = load_split(SPLIT_PATH)
-    cont_cols = split["continuous"]
-    disc_cols = split["discrete"]
+    time_col = split.get("time_column", "time")
+    cont_cols = [c for c in split["continuous"] if c != time_col]
+    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
+
     vocab = load_vocab()
     vocab_sizes = [len(vocab[c]) for c in disc_cols]

example/train.py (163 lines)
@@ -1,8 +1,12 @@
 #!/usr/bin/env python3
-"""Train hybrid diffusion on HAI 21.03 (minimal runnable example)."""
+"""Train hybrid diffusion on HAI (configurable runnable example)."""
 
+import argparse
 import json
 import os
+import random
+from pathlib import Path
+from typing import Dict
 
 import torch
 import torch.nn.functional as F
@@ -15,83 +19,140 @@ from hybrid_diffusion import (
     q_sample_discrete,
 )
 
-DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
-SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
-STATS_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/cont_stats.json"
-VOCAB_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/results/disc_vocab.json"
-OUT_DIR = "/home/anay/Dev/diffusion/mask-ddpm/example/results"
-
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-TIMESTEPS = 1000
-BATCH_SIZE = 8
-SEQ_LEN = 64
-EPOCHS = 1
-MAX_BATCHES = 50
-LAMBDA = 0.5
-LR = 1e-3
+BASE_DIR = Path(__file__).resolve().parent
+REPO_DIR = BASE_DIR.parent.parent
+
+DEFAULTS = {
+    "data_path": str(REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"),
+    "split_path": str(BASE_DIR / "feature_split.json"),
+    "stats_path": str(BASE_DIR / "results" / "cont_stats.json"),
+    "vocab_path": str(BASE_DIR / "results" / "disc_vocab.json"),
+    "out_dir": str(BASE_DIR / "results"),
+    "device": "auto",
+    "timesteps": 1000,
+    "batch_size": 8,
+    "seq_len": 64,
+    "epochs": 1,
+    "max_batches": 50,
+    "lambda": 0.5,
+    "lr": 1e-3,
+    "seed": 1337,
+    "log_every": 10,
+    "ckpt_every": 50,
+}
 
 
-def load_stats():
-    with open(STATS_PATH, "r", encoding="ascii") as f:
+def load_json(path: str) -> Dict:
+    with open(path, "r", encoding="ascii") as f:
         return json.load(f)
 
 
-def load_vocab():
-    with open(VOCAB_PATH, "r", encoding="ascii") as f:
-        return json.load(f)["vocab"]
+def set_seed(seed: int):
+    random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+
+def resolve_device(mode: str) -> str:
+    mode = mode.lower()
+    if mode == "cpu":
+        return "cpu"
+    if mode == "cuda":
+        if not torch.cuda.is_available():
+            raise SystemExit("device set to cuda but CUDA is not available")
+        return "cuda"
+    if torch.cuda.is_available():
+        return "cuda"
+    return "cpu"
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Train hybrid diffusion on HAI.")
+    parser.add_argument("--config", default=None, help="Path to JSON config.")
+    return parser.parse_args()
+
+
+def resolve_config_paths(config, base_dir: Path):
+    keys = ["data_path", "split_path", "stats_path", "vocab_path", "out_dir"]
+    for key in keys:
+        if key in config:
+            path = Path(str(config[key]))
+            if not path.is_absolute():
+                config[key] = str((base_dir / path).resolve())
+    return config
 
 
 def main():
-    split = load_split(SPLIT_PATH)
-    cont_cols = split["continuous"]
-    disc_cols = split["discrete"]
-
-    stats = load_stats()
+    args = parse_args()
+    config = dict(DEFAULTS)
+    if args.config:
+        cfg_path = Path(args.config).resolve()
+        config.update(load_json(str(cfg_path)))
+        config = resolve_config_paths(config, cfg_path.parent)
+    else:
+        config = resolve_config_paths(config, BASE_DIR)
+
+    set_seed(int(config["seed"]))
+
+    split = load_split(config["split_path"])
+    time_col = split.get("time_column", "time")
+    cont_cols = [c for c in split["continuous"] if c != time_col]
+    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
+
+    stats = load_json(config["stats_path"])
     mean = stats["mean"]
     std = stats["std"]
-    vocab = load_vocab()
+
+    vocab = load_json(config["vocab_path"])["vocab"]
     vocab_sizes = [len(vocab[c]) for c in disc_cols]
 
-    print("device", DEVICE)
-    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
-    opt = torch.optim.Adam(model.parameters(), lr=LR)
+    device = resolve_device(str(config["device"]))
+    print("device", device)
+    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device)
+    opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
 
-    betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
+    betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, dim=0)
 
-    os.makedirs(OUT_DIR, exist_ok=True)
+    os.makedirs(config["out_dir"], exist_ok=True)
+    log_path = os.path.join(config["out_dir"], "train_log.csv")
+    with open(log_path, "w", encoding="ascii") as f:
+        f.write("epoch,step,loss,loss_cont,loss_disc\n")
 
-    for epoch in range(EPOCHS):
+    total_step = 0
+    for epoch in range(int(config["epochs"])):
         for step, (x_cont, x_disc) in enumerate(
             windowed_batches(
-                DATA_PATH,
+                config["data_path"],
                 cont_cols,
                 disc_cols,
                 vocab,
                 mean,
                 std,
-                batch_size=BATCH_SIZE,
-                seq_len=SEQ_LEN,
-                max_batches=MAX_BATCHES,
+                batch_size=int(config["batch_size"]),
+                seq_len=int(config["seq_len"]),
+                max_batches=int(config["max_batches"]),
             )
         ):
-            x_cont = x_cont.to(DEVICE)
-            x_disc = x_disc.to(DEVICE)
+            x_cont = x_cont.to(device)
+            x_disc = x_disc.to(device)
 
             bsz = x_cont.size(0)
-            t = torch.randint(0, TIMESTEPS, (bsz,), device=DEVICE)
+            t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
 
             x_cont_t, noise = q_sample_continuous(x_cont, t, alphas_cumprod)
 
-            mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
-            x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, TIMESTEPS)
+            mask_tokens = torch.tensor(vocab_sizes, device=device)
+            x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"]))
 
             eps_pred, logits = model(x_cont_t, x_disc_t, t)
 
             loss_cont = F.mse_loss(eps_pred, noise)
 
             loss_disc = 0.0
             for i, logit in enumerate(logits):
                 if mask[:, :, i].any():
@@ -99,15 +160,31 @@ def main():
                     logit[mask[:, :, i]], x_disc[:, :, i][mask[:, :, i]]
                 )
 
-            loss = LAMBDA * loss_cont + (1 - LAMBDA) * loss_disc
+            lam = float(config["lambda"])
+            loss = lam * loss_cont + (1 - lam) * loss_disc
             opt.zero_grad()
             loss.backward()
             opt.step()
 
-            if step % 10 == 0:
+            if step % int(config["log_every"]) == 0:
                 print("epoch", epoch, "step", step, "loss", float(loss))
+                with open(log_path, "a", encoding="ascii") as f:
+                    f.write(
+                        "%d,%d,%.6f,%.6f,%.6f\n"
+                        % (epoch, step, float(loss), float(loss_cont), float(loss_disc))
+                    )
 
-    torch.save(model.state_dict(), os.path.join(OUT_DIR, "model.pt"))
+            total_step += 1
+            if total_step % int(config["ckpt_every"]) == 0:
+                ckpt = {
+                    "model": model.state_dict(),
+                    "optim": opt.state_dict(),
+                    "config": config,
+                    "step": total_step,
+                }
+                torch.save(ckpt, os.path.join(config["out_dir"], "model_ckpt.pt"))
+
+    torch.save(model.state_dict(), os.path.join(config["out_dir"], "model.pt"))
 
 
 if __name__ == "__main__":
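For reference, the objective optimized in the `train.py` loop above mixes the two branches of the hybrid model with the `lambda` weight from the config (0.5 by default):

```latex
\mathcal{L} = \lambda\,\mathbb{E}\big[\lVert \epsilon - \epsilon_\theta(x_t, t)\rVert^2\big]
\;+\; (1-\lambda)\,\mathrm{CE}\big(\text{logits at masked positions},\ \text{true tokens}\big)
```

where the first term is the MSE on the predicted noise for the continuous columns and the second is the cross-entropy computed only over positions masked by `q_sample_discrete`.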
@@ -8,6 +8,7 @@ import csv
 import gzip
 import json
 import math
+from pathlib import Path
 from typing import Dict, List, Tuple
 
 import torch
@@ -20,8 +21,10 @@ from hybrid_diffusion import (
     q_sample_discrete,
 )
 
-DATA_PATH = "/home/anay/Dev/diffusion/dataset/hai/hai-21.03/train1.csv.gz"
-SPLIT_PATH = "/home/anay/Dev/diffusion/mask-ddpm/example/feature_split.json"
+BASE_DIR = Path(__file__).resolve().parent
+REPO_DIR = BASE_DIR.parent.parent
+DATA_PATH = str(REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz")
+SPLIT_PATH = str(BASE_DIR / "feature_split.json")
 DEVICE = "cpu"
 TIMESTEPS = 1000
 