This commit is contained in:
2026-01-22 20:42:10 +08:00
parent f37a8ce179
commit 382c756dfe
10 changed files with 310 additions and 55 deletions

View File

@@ -1,7 +1,7 @@
# Example: HAI 21.03 Feature Split # Example: HAI 21.03 Feature Split
This folder contains a small, reproducible example that inspects the HAI 21.03 This folder contains a small, reproducible example that inspects the HAI 21.03
CSV (train1) and produces a continuous/discrete split using a simple heuristic. CSV (all train*.csv.gz files) and produces a continuous/discrete split using a simple heuristic.
## Files ## Files
- analyze_hai21_03.py: reads a sample of the data and writes results. - analyze_hai21_03.py: reads a sample of the data and writes results.
@@ -60,6 +60,12 @@ python example/run_pipeline.py --device auto
- Set `device` in `example/config.json` to `auto` or `cuda` when moving to a GPU machine. - Set `device` in `example/config.json` to `auto` or `cuda` when moving to a GPU machine.
- Attack label columns (`attack*`) are excluded from training and generation. - Attack label columns (`attack*`) are excluded from training and generation.
- `time` column is always excluded from training and generation (optional for export only). - `time` column is always excluded from training and generation (optional for export only).
- EMA weights are saved as `model_ema.pt` and used by the pipeline for sampling.
- Gradients are clipped by default (`grad_clip` in `config.json`) to stabilize training.
- Discrete masking uses a cosine schedule for smoother corruption.
- Continuous sampling is clipped in normalized space each step for stability.
- Optional conditioning by file id (`train*.csv.gz`) is enabled by default for multi-file training.
- Continuous head can be bounded with `tanh` via `use_tanh_eps` in config.
- The script only samples the first 5000 rows to stay fast. - The script only samples the first 5000 rows to stay fast.
- `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it. - `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
- `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU. - `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU.

View File

@@ -1,5 +1,6 @@
{ {
"data_path": "../../dataset/hai/hai-21.03/train1.csv.gz", "data_path": "../../dataset/hai/hai-21.03/train1.csv.gz",
"data_glob": "../../dataset/hai/hai-21.03/train*.csv.gz",
"split_path": "./feature_split.json", "split_path": "./feature_split.json",
"stats_path": "./results/cont_stats.json", "stats_path": "./results/cont_stats.json",
"vocab_path": "./results/disc_vocab.json", "vocab_path": "./results/disc_vocab.json",
@@ -14,5 +15,16 @@
"lr": 0.0005, "lr": 0.0005,
"seed": 1337, "seed": 1337,
"log_every": 10, "log_every": 10,
"ckpt_every": 50 "ckpt_every": 50,
"ema_decay": 0.999,
"use_ema": true,
"clip_k": 5.0,
"grad_clip": 1.0,
"use_condition": true,
"condition_type": "file_id",
"cond_dim": 32,
"use_tanh_eps": true,
"eps_scale": 1.0,
"sample_batch_size": 8,
"sample_seq_len": 128
} }

View File

@@ -4,7 +4,7 @@
import csv import csv
import gzip import gzip
import json import json
from typing import Dict, Iterable, List, Optional, Tuple from typing import Dict, Iterable, List, Optional, Tuple, Union
@@ -13,15 +13,18 @@ def load_split(path: str) -> Dict[str, List[str]]:
return json.load(f) return json.load(f)
def iter_rows(path: str) -> Iterable[Dict[str, str]]: def iter_rows(path_or_paths: Union[str, List[str]]) -> Iterable[Dict[str, str]]:
with gzip.open(path, "rt", newline="") as f: paths = [path_or_paths] if isinstance(path_or_paths, str) else list(path_or_paths)
reader = csv.DictReader(f) for path in paths:
for row in reader: opener = gzip.open if str(path).endswith(".gz") else open
yield row with opener(path, "rt", newline="") as f:
reader = csv.DictReader(f)
for row in reader:
yield row
def compute_cont_stats( def compute_cont_stats(
path: str, path: Union[str, List[str]],
cont_cols: List[str], cont_cols: List[str],
max_rows: Optional[int] = None, max_rows: Optional[int] = None,
) -> Tuple[Dict[str, float], Dict[str, float]]: ) -> Tuple[Dict[str, float], Dict[str, float]]:
@@ -52,7 +55,7 @@ def compute_cont_stats(
def build_vocab( def build_vocab(
path: str, path: Union[str, List[str]],
disc_cols: List[str], disc_cols: List[str],
max_rows: Optional[int] = None, max_rows: Optional[int] = None,
) -> Dict[str, Dict[str, int]]: ) -> Dict[str, Dict[str, int]]:
@@ -80,7 +83,7 @@ def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[st
def windowed_batches( def windowed_batches(
path: str, path: Union[str, List[str]],
cont_cols: List[str], cont_cols: List[str],
disc_cols: List[str], disc_cols: List[str],
vocab: Dict[str, Dict[str, int]], vocab: Dict[str, Dict[str, int]],
@@ -89,10 +92,12 @@ def windowed_batches(
batch_size: int, batch_size: int,
seq_len: int, seq_len: int,
max_batches: Optional[int] = None, max_batches: Optional[int] = None,
return_file_id: bool = False,
): ):
import torch import torch
batch_cont = [] batch_cont = []
batch_disc = [] batch_disc = []
batch_file = []
seq_cont = [] seq_cont = []
seq_disc = [] seq_disc = []
@@ -105,22 +110,34 @@ def windowed_batches(
seq_disc = [] seq_disc = []
batches_yielded = 0 batches_yielded = 0
for row in iter_rows(path): paths = [path] if isinstance(path, str) else list(path)
cont_row = [float(row[c]) for c in cont_cols] for file_id, p in enumerate(paths):
disc_row = [vocab[c].get(row[c], vocab[c]["<UNK>"]) for c in disc_cols] for row in iter_rows(p):
seq_cont.append(cont_row) cont_row = [float(row[c]) for c in cont_cols]
seq_disc.append(disc_row) disc_row = [vocab[c].get(row[c], vocab[c]["<UNK>"]) for c in disc_cols]
if len(seq_cont) == seq_len: seq_cont.append(cont_row)
flush_seq() seq_disc.append(disc_row)
if len(batch_cont) == batch_size: if len(seq_cont) == seq_len:
x_cont = torch.tensor(batch_cont, dtype=torch.float32) flush_seq()
x_disc = torch.tensor(batch_disc, dtype=torch.long) if return_file_id:
x_cont = normalize_cont(x_cont, cont_cols, mean, std) batch_file.append(file_id)
yield x_cont, x_disc if len(batch_cont) == batch_size:
batch_cont = [] x_cont = torch.tensor(batch_cont, dtype=torch.float32)
batch_disc = [] x_disc = torch.tensor(batch_disc, dtype=torch.long)
batches_yielded += 1 x_cont = normalize_cont(x_cont, cont_cols, mean, std)
if max_batches is not None and batches_yielded >= max_batches: if return_file_id:
return x_file = torch.tensor(batch_file, dtype=torch.long)
yield x_cont, x_disc, x_file
else:
yield x_cont, x_disc
batch_cont = []
batch_disc = []
batch_file = []
batches_yielded += 1
if max_batches is not None and batches_yielded >= max_batches:
return
# drop partial sequence at file boundary
seq_cont = []
seq_disc = []
# Drop last partial batch for simplicity # Drop last partial batch for simplicity

View File

@@ -54,6 +54,7 @@ def parse_args():
base_dir = Path(__file__).resolve().parent base_dir = Path(__file__).resolve().parent
repo_dir = base_dir.parent.parent repo_dir = base_dir.parent.parent
parser.add_argument("--data-path", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz")) parser.add_argument("--data-path", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"))
parser.add_argument("--data-glob", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"))
parser.add_argument("--split-path", default=str(base_dir / "feature_split.json")) parser.add_argument("--split-path", default=str(base_dir / "feature_split.json"))
parser.add_argument("--stats-path", default=str(base_dir / "results" / "cont_stats.json")) parser.add_argument("--stats-path", default=str(base_dir / "results" / "cont_stats.json"))
parser.add_argument("--vocab-path", default=str(base_dir / "results" / "disc_vocab.json")) parser.add_argument("--vocab-path", default=str(base_dir / "results" / "disc_vocab.json"))
@@ -64,6 +65,11 @@ def parse_args():
parser.add_argument("--batch-size", type=int, default=2) parser.add_argument("--batch-size", type=int, default=2)
parser.add_argument("--device", default="auto", help="cpu, cuda, or auto") parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
parser.add_argument("--include-time", action="store_true", help="Include time column as a simple index") parser.add_argument("--include-time", action="store_true", help="Include time column as a simple index")
parser.add_argument("--clip-k", type=float, default=5.0, help="Clip continuous values to mean±k*std")
parser.add_argument("--use-ema", action="store_true", help="Use EMA weights if available")
parser.add_argument("--config", default=None, help="Optional config_used.json to infer conditioning")
parser.add_argument("--condition-id", type=int, default=-1, help="Condition file id (0..N-1), -1=random")
parser.add_argument("--include-condition", action="store_true", help="Include condition id column in CSV")
return parser.parse_args() return parser.parse_args()
@@ -75,6 +81,15 @@ def main():
if not os.path.exists(args.model_path): if not os.path.exists(args.model_path):
raise SystemExit("missing model file: %s" % args.model_path) raise SystemExit("missing model file: %s" % args.model_path)
# resolve header source
data_path = args.data_path
if args.data_glob:
base = Path(args.data_glob).parent
pat = Path(args.data_glob).name
matches = sorted(base.glob(pat))
if matches:
data_path = str(matches[0])
split = load_split(args.split_path) split = load_split(args.split_path)
time_col = split.get("time_column", "time") time_col = split.get("time_column", "time")
cont_cols = [c for c in split["continuous"] if c != time_col] cont_cols = [c for c in split["continuous"] if c != time_col]
@@ -89,8 +104,31 @@ def main():
vocab_sizes = [len(vocab[c]) for c in disc_cols] vocab_sizes = [len(vocab[c]) for c in disc_cols]
device = resolve_device(args.device) device = resolve_device(args.device)
model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device) cfg = {}
model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True)) use_condition = False
cond_vocab_size = 0
if args.config and os.path.exists(args.config):
with open(args.config, "r", encoding="utf-8") as f:
cfg = json.load(f)
use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
if use_condition:
base = Path(cfg.get("data_glob", args.data_glob)).parent
pat = Path(cfg.get("data_glob", args.data_glob)).name
cond_vocab_size = len(sorted(base.glob(pat)))
model = HybridDiffusionModel(
cont_dim=len(cont_cols),
disc_vocab_sizes=vocab_sizes,
cond_vocab_size=cond_vocab_size if use_condition else 0,
cond_dim=int(cfg.get("cond_dim", 32)),
use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
eps_scale=float(cfg.get("eps_scale", 1.0)),
).to(device)
if args.use_ema and os.path.exists(args.model_path.replace("model.pt", "model_ema.pt")):
ema_path = args.model_path.replace("model.pt", "model_ema.pt")
model.load_state_dict(torch.load(ema_path, map_location=device, weights_only=True))
else:
model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
model.eval() model.eval()
betas = cosine_beta_schedule(args.timesteps).to(device) betas = cosine_beta_schedule(args.timesteps).to(device)
@@ -108,9 +146,20 @@ def main():
for i in range(len(disc_cols)): for i in range(len(disc_cols)):
x_disc[:, :, i] = mask_tokens[i] x_disc[:, :, i] = mask_tokens[i]
# condition id
cond = None
if use_condition:
if cond_vocab_size <= 0:
raise SystemExit("use_condition enabled but no files matched data_glob")
if args.condition_id < 0:
cond_id = torch.randint(0, cond_vocab_size, (args.batch_size,), device=device)
else:
cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
cond = cond_id
for t in reversed(range(args.timesteps)): for t in reversed(range(args.timesteps)):
t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long) t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
eps_pred, logits = model(x_cont, x_disc, t_batch) eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
a_t = alphas[t] a_t = alphas[t]
a_bar_t = alphas_cumprod[t] a_bar_t = alphas_cumprod[t]
@@ -122,6 +171,8 @@ def main():
x_cont = mean_x + torch.sqrt(betas[t]) * noise x_cont = mean_x + torch.sqrt(betas[t]) * noise
else: else:
x_cont = mean_x x_cont = mean_x
if args.clip_k > 0:
x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
for i, logit in enumerate(logits): for i, logit in enumerate(logits):
if t == 0: if t == 0:
@@ -136,15 +187,22 @@ def main():
) )
x_disc[:, :, i][mask] = sampled[mask] x_disc[:, :, i][mask] = sampled[mask]
# move to CPU for export
x_cont = x_cont.cpu() x_cont = x_cont.cpu()
x_disc = x_disc.cpu() x_disc = x_disc.cpu()
# clip in normalized space to avoid extreme blow-up
if args.clip_k > 0:
x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype) mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype) std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
x_cont = x_cont * std_vec + mean_vec x_cont = x_cont * std_vec + mean_vec
header = read_header(args.data_path) header = read_header(data_path)
out_cols = [c for c in header if c != time_col or args.include_time] out_cols = [c for c in header if c != time_col or args.include_time]
if args.include_condition and use_condition:
out_cols = ["__cond_file_id"] + out_cols
os.makedirs(os.path.dirname(args.out), exist_ok=True) os.makedirs(os.path.dirname(args.out), exist_ok=True)
with open(args.out, "w", newline="", encoding="utf-8") as f: with open(args.out, "w", newline="", encoding="utf-8") as f:
@@ -155,6 +213,8 @@ def main():
for b in range(args.batch_size): for b in range(args.batch_size):
for t in range(args.seq_len): for t in range(args.seq_len):
row = {} row = {}
if args.include_condition and use_condition:
row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1"
if args.include_time and time_col in header: if args.include_time and time_col in header:
row[time_col] = str(row_index) row[time_col] = str(row_index)
for i, c in enumerate(cont_cols): for i, c in enumerate(cont_cols):

View File

@@ -36,9 +36,10 @@ def q_sample_discrete(
mask_tokens: torch.Tensor, mask_tokens: torch.Tensor,
max_t: int, max_t: int,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Randomly mask discrete tokens with a linear schedule over t.""" """Randomly mask discrete tokens with a cosine schedule over t."""
bsz = x0.size(0) bsz = x0.size(0)
p = t.float() / float(max_t) # cosine schedule: p(0)=0, p(max_t)=1
p = 0.5 * (1.0 - torch.cos(math.pi * t.float() / float(max_t)))
p = p.view(bsz, 1, 1) p = p.view(bsz, 1, 1)
mask = torch.rand_like(x0.float()) < p mask = torch.rand_like(x0.float()) < p
x_masked = x0.clone() x_masked = x0.clone()
@@ -69,12 +70,24 @@ class HybridDiffusionModel(nn.Module):
disc_vocab_sizes: List[int], disc_vocab_sizes: List[int],
time_dim: int = 64, time_dim: int = 64,
hidden_dim: int = 256, hidden_dim: int = 256,
cond_vocab_size: int = 0,
cond_dim: int = 32,
use_tanh_eps: bool = False,
eps_scale: float = 1.0,
): ):
super().__init__() super().__init__()
self.cont_dim = cont_dim self.cont_dim = cont_dim
self.disc_vocab_sizes = disc_vocab_sizes self.disc_vocab_sizes = disc_vocab_sizes
self.time_embed = SinusoidalTimeEmbedding(time_dim) self.time_embed = SinusoidalTimeEmbedding(time_dim)
self.use_tanh_eps = use_tanh_eps
self.eps_scale = eps_scale
self.cond_vocab_size = cond_vocab_size
self.cond_dim = cond_dim
self.cond_embed = None
if cond_vocab_size and cond_vocab_size > 0:
self.cond_embed = nn.Embedding(cond_vocab_size, cond_dim)
self.disc_embeds = nn.ModuleList([ self.disc_embeds = nn.ModuleList([
nn.Embedding(vocab_size + 1, min(32, vocab_size * 2)) nn.Embedding(vocab_size + 1, min(32, vocab_size * 2))
@@ -83,7 +96,8 @@ class HybridDiffusionModel(nn.Module):
disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds) disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)
self.cont_proj = nn.Linear(cont_dim, cont_dim) self.cont_proj = nn.Linear(cont_dim, cont_dim)
self.in_proj = nn.Linear(cont_dim + disc_embed_dim + time_dim, hidden_dim) in_dim = cont_dim + disc_embed_dim + time_dim + (cond_dim if self.cond_embed is not None else 0)
self.in_proj = nn.Linear(in_dim, hidden_dim)
self.backbone = nn.GRU(hidden_dim, hidden_dim, batch_first=True) self.backbone = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
self.cont_head = nn.Linear(hidden_dim, cont_dim) self.cont_head = nn.Linear(hidden_dim, cont_dim)
@@ -92,7 +106,7 @@ class HybridDiffusionModel(nn.Module):
for vocab_size in disc_vocab_sizes for vocab_size in disc_vocab_sizes
]) ])
def forward(self, x_cont: torch.Tensor, x_disc: torch.Tensor, t: torch.Tensor): def forward(self, x_cont: torch.Tensor, x_disc: torch.Tensor, t: torch.Tensor, cond: torch.Tensor = None):
"""x_cont: (B,T,Cc), x_disc: (B,T,Cd) with integer tokens.""" """x_cont: (B,T,Cc), x_disc: (B,T,Cd) with integer tokens."""
time_emb = self.time_embed(t) time_emb = self.time_embed(t)
time_emb = time_emb.unsqueeze(1).expand(-1, x_cont.size(1), -1) time_emb = time_emb.unsqueeze(1).expand(-1, x_cont.size(1), -1)
@@ -102,12 +116,23 @@ class HybridDiffusionModel(nn.Module):
disc_embs.append(emb(x_disc[:, :, i])) disc_embs.append(emb(x_disc[:, :, i]))
disc_feat = torch.cat(disc_embs, dim=-1) disc_feat = torch.cat(disc_embs, dim=-1)
cond_feat = None
if self.cond_embed is not None:
if cond is None:
raise ValueError("cond is required when cond_vocab_size > 0")
cond_feat = self.cond_embed(cond).unsqueeze(1).expand(-1, x_cont.size(1), -1)
cont_feat = self.cont_proj(x_cont) cont_feat = self.cont_proj(x_cont)
feat = torch.cat([cont_feat, disc_feat, time_emb], dim=-1) parts = [cont_feat, disc_feat, time_emb]
if cond_feat is not None:
parts.append(cond_feat)
feat = torch.cat(parts, dim=-1)
feat = self.in_proj(feat) feat = self.in_proj(feat)
out, _ = self.backbone(feat) out, _ = self.backbone(feat)
eps_pred = self.cont_head(out) eps_pred = self.cont_head(out)
if self.use_tanh_eps:
eps_pred = torch.tanh(eps_pred) * self.eps_scale
logits = [head(out) for head in self.disc_heads] logits = [head(out) for head in self.disc_heads]
return eps_pred, logits return eps_pred, logits

View File

@@ -10,7 +10,7 @@ from platform_utils import safe_path, ensure_dir
BASE_DIR = Path(__file__).resolve().parent BASE_DIR = Path(__file__).resolve().parent
REPO_DIR = BASE_DIR.parent.parent REPO_DIR = BASE_DIR.parent.parent
DATA_PATH = REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz" DATA_GLOB = REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"
SPLIT_PATH = BASE_DIR / "feature_split.json" SPLIT_PATH = BASE_DIR / "feature_split.json"
OUT_STATS = BASE_DIR / "results" / "cont_stats.json" OUT_STATS = BASE_DIR / "results" / "cont_stats.json"
OUT_VOCAB = BASE_DIR / "results" / "disc_vocab.json" OUT_VOCAB = BASE_DIR / "results" / "disc_vocab.json"
@@ -22,8 +22,13 @@ def main(max_rows: Optional[int] = None):
cont_cols = [c for c in split["continuous"] if c != time_col] cont_cols = [c for c in split["continuous"] if c != time_col]
disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col] disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
mean, std = compute_cont_stats(safe_path(DATA_PATH), cont_cols, max_rows=max_rows) data_paths = sorted(Path(REPO_DIR / "dataset" / "hai" / "hai-21.03").glob("train*.csv.gz"))
vocab = build_vocab(safe_path(DATA_PATH), disc_cols, max_rows=max_rows) if not data_paths:
raise SystemExit("no train files found under %s" % str(DATA_GLOB))
data_paths = [safe_path(p) for p in data_paths]
mean, std = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows)
vocab = build_vocab(data_paths, disc_cols, max_rows=max_rows)
ensure_dir(OUT_STATS.parent) ensure_dir(OUT_STATS.parent)
with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f: with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f:

View File

@@ -5,6 +5,7 @@ import argparse
import subprocess import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
import json
from platform_utils import safe_path, is_windows from platform_utils import safe_path, is_windows
@@ -40,6 +41,13 @@ def parse_args():
def main(): def main():
args = parse_args() args = parse_args()
base_dir = Path(__file__).resolve().parent base_dir = Path(__file__).resolve().parent
config_path = Path(args.config)
with open(config_path, "r", encoding="utf-8") as f:
cfg = json.load(f)
timesteps = cfg.get("timesteps", 200)
seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64))
batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2))
clip_k = cfg.get("clip_k", 5.0)
run([sys.executable, str(base_dir / "prepare_data.py")]) run([sys.executable, str(base_dir / "prepare_data.py")])
run([sys.executable, str(base_dir / "train.py"), "--config", args.config, "--device", args.device]) run([sys.executable, str(base_dir / "train.py"), "--config", args.config, "--device", args.device])
run( run(
@@ -49,6 +57,17 @@ def main():
"--include-time", "--include-time",
"--device", "--device",
args.device, args.device,
"--config",
str(config_path),
"--timesteps",
str(timesteps),
"--seq-len",
str(seq_len),
"--batch-size",
str(batch_size),
"--clip-k",
str(clip_k),
"--use-ema",
] ]
) )
run([sys.executable, str(base_dir / "evaluate_generated.py")]) run([sys.executable, str(base_dir / "evaluate_generated.py")])

View File

@@ -17,6 +17,7 @@ BASE_DIR = Path(__file__).resolve().parent
SPLIT_PATH = BASE_DIR / "feature_split.json" SPLIT_PATH = BASE_DIR / "feature_split.json"
VOCAB_PATH = BASE_DIR / "results" / "disc_vocab.json" VOCAB_PATH = BASE_DIR / "results" / "disc_vocab.json"
MODEL_PATH = BASE_DIR / "results" / "model.pt" MODEL_PATH = BASE_DIR / "results" / "model.pt"
CONFIG_PATH = BASE_DIR / "results" / "config_used.json"
# 使用 platform_utils 中的 resolve_device 函数 # 使用 platform_utils 中的 resolve_device 函数
@@ -25,6 +26,7 @@ DEVICE = resolve_device("auto")
TIMESTEPS = 200 TIMESTEPS = 200
SEQ_LEN = 64 SEQ_LEN = 64
BATCH_SIZE = 2 BATCH_SIZE = 2
CLIP_K = 5.0
def load_vocab(): def load_vocab():
@@ -33,6 +35,19 @@ def load_vocab():
def main(): def main():
cfg = {}
if CONFIG_PATH.exists():
with open(str(CONFIG_PATH), "r", encoding="utf-8") as f:
cfg = json.load(f)
timesteps = int(cfg.get("timesteps", TIMESTEPS))
seq_len = int(cfg.get("sample_seq_len", cfg.get("seq_len", SEQ_LEN)))
batch_size = int(cfg.get("sample_batch_size", cfg.get("batch_size", BATCH_SIZE)))
clip_k = float(cfg.get("clip_k", CLIP_K))
use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
cond_dim = int(cfg.get("cond_dim", 32))
use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
eps_scale = float(cfg.get("eps_scale", 1.0))
split = load_split(str(SPLIT_PATH)) split = load_split(str(SPLIT_PATH))
time_col = split.get("time_column", "time") time_col = split.get("time_column", "time")
cont_cols = [c for c in split["continuous"] if c != time_col] cont_cols = [c for c in split["continuous"] if c != time_col]
@@ -42,26 +57,46 @@ def main():
vocab_sizes = [len(vocab[c]) for c in disc_cols] vocab_sizes = [len(vocab[c]) for c in disc_cols]
print("device", DEVICE) print("device", DEVICE)
model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE) cond_vocab_size = 0
if use_condition:
data_glob = cfg.get("data_glob")
if data_glob:
base = Path(data_glob).parent
pat = Path(data_glob).name
cond_vocab_size = len(sorted(base.glob(pat)))
model = HybridDiffusionModel(
cont_dim=len(cont_cols),
disc_vocab_sizes=vocab_sizes,
cond_vocab_size=cond_vocab_size,
cond_dim=cond_dim,
use_tanh_eps=use_tanh_eps,
eps_scale=eps_scale,
).to(DEVICE)
if MODEL_PATH.exists(): if MODEL_PATH.exists():
model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True)) model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
model.eval() model.eval()
betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE) betas = cosine_beta_schedule(timesteps).to(DEVICE)
alphas = 1.0 - betas alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_cumprod = torch.cumprod(alphas, dim=0)
x_cont = torch.randn(BATCH_SIZE, SEQ_LEN, len(cont_cols), device=DEVICE) x_cont = torch.randn(batch_size, seq_len, len(cont_cols), device=DEVICE)
x_disc = torch.full((BATCH_SIZE, SEQ_LEN, len(disc_cols)), 0, device=DEVICE, dtype=torch.long) x_disc = torch.full((batch_size, seq_len, len(disc_cols)), 0, device=DEVICE, dtype=torch.long)
mask_tokens = torch.tensor(vocab_sizes, device=DEVICE) mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)
# Initialize discrete with mask tokens # Initialize discrete with mask tokens
for i in range(len(disc_cols)): for i in range(len(disc_cols)):
x_disc[:, :, i] = mask_tokens[i] x_disc[:, :, i] = mask_tokens[i]
for t in reversed(range(TIMESTEPS)): cond = None
t_batch = torch.full((BATCH_SIZE,), t, device=DEVICE, dtype=torch.long) if use_condition:
eps_pred, logits = model(x_cont, x_disc, t_batch) if cond_vocab_size <= 0:
raise SystemExit("use_condition enabled but no files matched data_glob")
cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long)
for t in reversed(range(timesteps)):
t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
eps_pred, logits = model(x_cont, x_disc, t_batch, cond)
# Continuous reverse step (DDPM): x_{t-1} mean # Continuous reverse step (DDPM): x_{t-1} mean
a_t = alphas[t] a_t = alphas[t]
@@ -74,6 +109,8 @@ def main():
x_cont = mean + torch.sqrt(betas[t]) * noise x_cont = mean + torch.sqrt(betas[t]) * noise
else: else:
x_cont = mean x_cont = mean
if clip_k > 0:
x_cont = torch.clamp(x_cont, -clip_k, clip_k)
# Discrete: fill masked positions by sampling logits # Discrete: fill masked positions by sampling logits
for i, logit in enumerate(logits): for i, logit in enumerate(logits):

View File

@@ -26,6 +26,7 @@ REPO_DIR = BASE_DIR.parent.parent
DEFAULTS = { DEFAULTS = {
"data_path": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz", "data_path": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz",
"data_glob": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz",
"split_path": BASE_DIR / "feature_split.json", "split_path": BASE_DIR / "feature_split.json",
"stats_path": BASE_DIR / "results" / "cont_stats.json", "stats_path": BASE_DIR / "results" / "cont_stats.json",
"vocab_path": BASE_DIR / "results" / "disc_vocab.json", "vocab_path": BASE_DIR / "results" / "disc_vocab.json",
@@ -41,6 +42,15 @@ DEFAULTS = {
"seed": 1337, "seed": 1337,
"log_every": 10, "log_every": 10,
"ckpt_every": 50, "ckpt_every": 50,
"ema_decay": 0.999,
"use_ema": True,
"clip_k": 5.0,
"grad_clip": 1.0,
"use_condition": True,
"condition_type": "file_id",
"cond_dim": 32,
"use_tanh_eps": True,
"eps_scale": 1.0,
} }
@@ -69,7 +79,7 @@ def parse_args():
def resolve_config_paths(config, base_dir: Path): def resolve_config_paths(config, base_dir: Path):
keys = ["data_path", "split_path", "stats_path", "vocab_path", "out_dir"] keys = ["data_path", "data_glob", "split_path", "stats_path", "vocab_path", "out_dir"]
for key in keys: for key in keys:
if key in config: if key in config:
# 如果值是字符串转换为Path对象 # 如果值是字符串转换为Path对象
@@ -85,6 +95,26 @@ def resolve_config_paths(config, base_dir: Path):
return config return config
class EMA:
def __init__(self, model, decay: float):
self.decay = decay
self.shadow = {}
for name, param in model.named_parameters():
if param.requires_grad:
self.shadow[name] = param.detach().clone()
def update(self, model):
with torch.no_grad():
for name, param in model.named_parameters():
if not param.requires_grad:
continue
old = self.shadow[name]
self.shadow[name] = old * self.decay + param.detach() * (1.0 - self.decay)
def state_dict(self):
return self.shadow
def main(): def main():
args = parse_args() args = parse_args()
config = dict(DEFAULTS) config = dict(DEFAULTS)
@@ -113,25 +143,47 @@ def main():
vocab = load_json(config["vocab_path"])["vocab"] vocab = load_json(config["vocab_path"])["vocab"]
vocab_sizes = [len(vocab[c]) for c in disc_cols] vocab_sizes = [len(vocab[c]) for c in disc_cols]
data_paths = None
if "data_glob" in config and config["data_glob"]:
data_paths = sorted(Path(config["data_glob"]).parent.glob(Path(config["data_glob"]).name))
if data_paths:
data_paths = [safe_path(p) for p in data_paths]
if not data_paths:
data_paths = [safe_path(config["data_path"])]
use_condition = bool(config.get("use_condition")) and config.get("condition_type") == "file_id"
cond_vocab_size = len(data_paths) if use_condition else 0
device = resolve_device(str(config["device"])) device = resolve_device(str(config["device"]))
print("device", device) print("device", device)
model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device) model = HybridDiffusionModel(
cont_dim=len(cont_cols),
disc_vocab_sizes=vocab_sizes,
cond_vocab_size=cond_vocab_size,
cond_dim=int(config.get("cond_dim", 32)),
use_tanh_eps=bool(config.get("use_tanh_eps", False)),
eps_scale=float(config.get("eps_scale", 1.0)),
).to(device)
opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"])) opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None
betas = cosine_beta_schedule(int(config["timesteps"])).to(device) betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
alphas = 1.0 - betas alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_cumprod = torch.cumprod(alphas, dim=0)
os.makedirs(config["out_dir"], exist_ok=True) os.makedirs(config["out_dir"], exist_ok=True)
log_path = os.path.join(config["out_dir"], "train_log.csv") out_dir = safe_path(config["out_dir"])
log_path = os.path.join(out_dir, "train_log.csv")
with open(log_path, "w", encoding="utf-8") as f: with open(log_path, "w", encoding="utf-8") as f:
f.write("epoch,step,loss,loss_cont,loss_disc\n") f.write("epoch,step,loss,loss_cont,loss_disc\n")
with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f:
json.dump(config, f, indent=2)
total_step = 0 total_step = 0
for epoch in range(int(config["epochs"])): for epoch in range(int(config["epochs"])):
for step, (x_cont, x_disc) in enumerate( for step, batch in enumerate(
windowed_batches( windowed_batches(
config["data_path"], data_paths,
cont_cols, cont_cols,
disc_cols, disc_cols,
vocab, vocab,
@@ -140,8 +192,15 @@ def main():
batch_size=int(config["batch_size"]), batch_size=int(config["batch_size"]),
seq_len=int(config["seq_len"]), seq_len=int(config["seq_len"]),
max_batches=int(config["max_batches"]), max_batches=int(config["max_batches"]),
return_file_id=use_condition,
) )
): ):
if use_condition:
x_cont, x_disc, cond = batch
cond = cond.to(device)
else:
x_cont, x_disc = batch
cond = None
x_cont = x_cont.to(device) x_cont = x_cont.to(device)
x_disc = x_disc.to(device) x_disc = x_disc.to(device)
@@ -153,21 +212,29 @@ def main():
mask_tokens = torch.tensor(vocab_sizes, device=device) mask_tokens = torch.tensor(vocab_sizes, device=device)
x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"])) x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"]))
eps_pred, logits = model(x_cont_t, x_disc_t, t) eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)
loss_cont = F.mse_loss(eps_pred, noise) loss_cont = F.mse_loss(eps_pred, noise)
loss_disc = 0.0 loss_disc = 0.0
loss_disc_count = 0
for i, logit in enumerate(logits): for i, logit in enumerate(logits):
if mask[:, :, i].any(): if mask[:, :, i].any():
loss_disc = loss_disc + F.cross_entropy( loss_disc = loss_disc + F.cross_entropy(
logit[mask[:, :, i]], x_disc[:, :, i][mask[:, :, i]] logit[mask[:, :, i]], x_disc[:, :, i][mask[:, :, i]]
) )
loss_disc_count += 1
if loss_disc_count > 0:
loss_disc = loss_disc / loss_disc_count
lam = float(config["lambda"]) lam = float(config["lambda"])
loss = lam * loss_cont + (1 - lam) * loss_disc loss = lam * loss_cont + (1 - lam) * loss_disc
opt.zero_grad() opt.zero_grad()
loss.backward() loss.backward()
if float(config.get("grad_clip", 0.0)) > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
opt.step() opt.step()
if ema is not None:
ema.update(model)
if step % int(config["log_every"]) == 0: if step % int(config["log_every"]) == 0:
print("epoch", epoch, "step", step, "loss", float(loss)) print("epoch", epoch, "step", step, "loss", float(loss))
@@ -185,9 +252,13 @@ def main():
"config": config, "config": config,
"step": total_step, "step": total_step,
} }
torch.save(ckpt, os.path.join(config["out_dir"], "model_ckpt.pt")) if ema is not None:
ckpt["ema"] = ema.state_dict()
torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))
torch.save(model.state_dict(), os.path.join(config["out_dir"], "model.pt")) torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
if ema is not None:
torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))
if __name__ == "__main__": if __name__ == "__main__":

3
requirements.txt Normal file
View File

@@ -0,0 +1,3 @@
torch
numpy
matplotlib