update
@@ -1,7 +1,7 @@
 # Example: HAI 21.03 Feature Split

 This folder contains a small, reproducible example that inspects the HAI 21.03
-CSV (train1) and produces a continuous/discrete split using a simple heuristic.
+CSVs (all train*.csv.gz files) and produces a continuous/discrete split using a simple heuristic.

 ## Files
 - analyze_hai21_03.py: reads a sample of the data and writes results.
@@ -60,6 +60,12 @@ python example/run_pipeline.py --device auto
 - Set `device` in `example/config.json` to `auto` or `cuda` when moving to a GPU machine.
 - Attack label columns (`attack*`) are excluded from training and generation.
 - The `time` column is always excluded from training and generation (optional for export only).
+- EMA weights are saved as `model_ema.pt` and used by the pipeline for sampling.
+- Gradients are clipped by default (`grad_clip` in `config.json`) to stabilize training.
+- Discrete masking uses a cosine schedule for smoother corruption (see the sketch after this list).
+- Continuous sampling is clipped in normalized space at each step for stability.
+- Optional conditioning on file id (`train*.csv.gz`) is enabled by default for multi-file training.
+- The continuous head can be bounded with `tanh` via `use_tanh_eps` in the config.
 - The script only samples the first 5000 rows to stay fast.
 - `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it.
 - `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU.
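For reference, here is a minimal sketch of the cosine masking probability used for discrete corruption; it mirrors `q_sample_discrete` in the model diff below (the tensor values are illustrative):

```python
import math
import torch

def mask_prob(t: torch.Tensor, max_t: int) -> torch.Tensor:
    # Cosine schedule: p(0) = 0, p(max_t) = 1, smooth in between.
    return 0.5 * (1.0 - torch.cos(math.pi * t.float() / float(max_t)))

t = torch.tensor([0, 100, 200])
print(mask_prob(t, max_t=200))  # tensor([0.0000, 0.5000, 1.0000])
```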
@@ -1,5 +1,6 @@
 {
   "data_path": "../../dataset/hai/hai-21.03/train1.csv.gz",
+  "data_glob": "../../dataset/hai/hai-21.03/train*.csv.gz",
   "split_path": "./feature_split.json",
   "stats_path": "./results/cont_stats.json",
   "vocab_path": "./results/disc_vocab.json",
@@ -14,5 +15,16 @@
   "lr": 0.0005,
   "seed": 1337,
   "log_every": 10,
-  "ckpt_every": 50
+  "ckpt_every": 50,
+  "ema_decay": 0.999,
+  "use_ema": true,
+  "clip_k": 5.0,
+  "grad_clip": 1.0,
+  "use_condition": true,
+  "condition_type": "file_id",
+  "cond_dim": 32,
+  "use_tanh_eps": true,
+  "eps_scale": 1.0,
+  "sample_batch_size": 8,
+  "sample_seq_len": 128
 }
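For context, a hedged sketch of how the new keys are consumed (mirroring the `train.py` changes later in this diff; the config path is illustrative): `data_glob` expands to the sorted list of matching files, and with `use_condition` plus `condition_type: "file_id"` the number of matched files becomes the conditioning vocabulary size.

```python
import json
from pathlib import Path

with open("example/config.json", encoding="utf-8") as f:
    cfg = json.load(f)

glob_path = Path(cfg["data_glob"])
data_paths = sorted(glob_path.parent.glob(glob_path.name))  # one file id per match
if not data_paths:
    data_paths = [Path(cfg["data_path"])]  # fall back to the single-file path

use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
cond_vocab_size = len(data_paths) if use_condition else 0
print(len(data_paths), cond_vocab_size)
```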
@@ -4,7 +4,7 @@
 import csv
 import gzip
 import json
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, Union


@@ -13,15 +13,18 @@ def load_split(path: str) -> Dict[str, List[str]]:
         return json.load(f)


-def iter_rows(path: str) -> Iterable[Dict[str, str]]:
-    with gzip.open(path, "rt", newline="") as f:
-        reader = csv.DictReader(f)
-        for row in reader:
-            yield row
+def iter_rows(path_or_paths: Union[str, List[str]]) -> Iterable[Dict[str, str]]:
+    paths = [path_or_paths] if isinstance(path_or_paths, str) else list(path_or_paths)
+    for path in paths:
+        opener = gzip.open if str(path).endswith(".gz") else open
+        with opener(path, "rt", newline="") as f:
+            reader = csv.DictReader(f)
+            for row in reader:
+                yield row


 def compute_cont_stats(
-    path: str,
+    path: Union[str, List[str]],
     cont_cols: List[str],
     max_rows: Optional[int] = None,
 ) -> Tuple[Dict[str, float], Dict[str, float]]:
@@ -52,7 +55,7 @@ def compute_cont_stats(


 def build_vocab(
-    path: str,
+    path: Union[str, List[str]],
     disc_cols: List[str],
     max_rows: Optional[int] = None,
 ) -> Dict[str, Dict[str, int]]:
@@ -80,7 +83,7 @@ def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[str, float]):


 def windowed_batches(
-    path: str,
+    path: Union[str, List[str]],
     cont_cols: List[str],
     disc_cols: List[str],
     vocab: Dict[str, Dict[str, int]],
@@ -89,10 +92,12 @@
     batch_size: int,
     seq_len: int,
     max_batches: Optional[int] = None,
+    return_file_id: bool = False,
 ):
     import torch
     batch_cont = []
     batch_disc = []
+    batch_file = []
     seq_cont = []
     seq_disc = []

@@ -105,22 +110,34 @@
         seq_disc = []

     batches_yielded = 0
-    for row in iter_rows(path):
-        cont_row = [float(row[c]) for c in cont_cols]
-        disc_row = [vocab[c].get(row[c], vocab[c]["<UNK>"]) for c in disc_cols]
-        seq_cont.append(cont_row)
-        seq_disc.append(disc_row)
-        if len(seq_cont) == seq_len:
-            flush_seq()
-            if len(batch_cont) == batch_size:
-                x_cont = torch.tensor(batch_cont, dtype=torch.float32)
-                x_disc = torch.tensor(batch_disc, dtype=torch.long)
-                x_cont = normalize_cont(x_cont, cont_cols, mean, std)
-                yield x_cont, x_disc
-                batch_cont = []
-                batch_disc = []
-                batches_yielded += 1
-                if max_batches is not None and batches_yielded >= max_batches:
-                    return
+    paths = [path] if isinstance(path, str) else list(path)
+    for file_id, p in enumerate(paths):
+        for row in iter_rows(p):
+            cont_row = [float(row[c]) for c in cont_cols]
+            disc_row = [vocab[c].get(row[c], vocab[c]["<UNK>"]) for c in disc_cols]
+            seq_cont.append(cont_row)
+            seq_disc.append(disc_row)
+            if len(seq_cont) == seq_len:
+                flush_seq()
+                if return_file_id:
+                    batch_file.append(file_id)
+                if len(batch_cont) == batch_size:
+                    x_cont = torch.tensor(batch_cont, dtype=torch.float32)
+                    x_disc = torch.tensor(batch_disc, dtype=torch.long)
+                    x_cont = normalize_cont(x_cont, cont_cols, mean, std)
+                    if return_file_id:
+                        x_file = torch.tensor(batch_file, dtype=torch.long)
+                        yield x_cont, x_disc, x_file
+                    else:
+                        yield x_cont, x_disc
+                    batch_cont = []
+                    batch_disc = []
+                    batch_file = []
+                    batches_yielded += 1
+                    if max_batches is not None and batches_yielded >= max_batches:
+                        return
+        # drop partial sequence at file boundary
+        seq_cont = []
+        seq_disc = []

     # Drop last partial batch for simplicity
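A hedged usage sketch of the extended iterator: the `mean`/`std` parameter order between `vocab` and `batch_size` is assumed from the elided middle of the signature, and `data_paths`, the column lists, and the stats are assumed to come from `prepare_data.py`.

```python
# Minimal sketch: iterate conditioned batches over several files.
for x_cont, x_disc, x_file in windowed_batches(
    data_paths,             # list of train*.csv.gz paths (file id = list index)
    cont_cols,
    disc_cols,
    vocab,
    mean,
    std,
    batch_size=4,
    seq_len=64,
    max_batches=10,
    return_file_id=True,    # also yield a (B,) tensor of file ids
):
    print(x_cont.shape, x_disc.shape, x_file.shape)  # (4, 64, Cc) (4, 64, Cd) (4,)
```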
@@ -54,6 +54,7 @@ def parse_args():
     base_dir = Path(__file__).resolve().parent
     repo_dir = base_dir.parent.parent
     parser.add_argument("--data-path", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"))
+    parser.add_argument("--data-glob", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"))
     parser.add_argument("--split-path", default=str(base_dir / "feature_split.json"))
     parser.add_argument("--stats-path", default=str(base_dir / "results" / "cont_stats.json"))
     parser.add_argument("--vocab-path", default=str(base_dir / "results" / "disc_vocab.json"))
@@ -64,6 +65,11 @@ def parse_args():
     parser.add_argument("--batch-size", type=int, default=2)
     parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
     parser.add_argument("--include-time", action="store_true", help="Include time column as a simple index")
+    parser.add_argument("--clip-k", type=float, default=5.0, help="Clip continuous values to mean±k*std")
+    parser.add_argument("--use-ema", action="store_true", help="Use EMA weights if available")
+    parser.add_argument("--config", default=None, help="Optional config_used.json to infer conditioning")
+    parser.add_argument("--condition-id", type=int, default=-1, help="Condition file id (0..N-1), -1=random")
+    parser.add_argument("--include-condition", action="store_true", help="Include condition id column in CSV")
     return parser.parse_args()


@@ -75,6 +81,15 @@ def main():
     if not os.path.exists(args.model_path):
         raise SystemExit("missing model file: %s" % args.model_path)

+    # resolve header source
+    data_path = args.data_path
+    if args.data_glob:
+        base = Path(args.data_glob).parent
+        pat = Path(args.data_glob).name
+        matches = sorted(base.glob(pat))
+        if matches:
+            data_path = str(matches[0])
+
     split = load_split(args.split_path)
     time_col = split.get("time_column", "time")
     cont_cols = [c for c in split["continuous"] if c != time_col]
@@ -89,8 +104,31 @@ def main():
     vocab_sizes = [len(vocab[c]) for c in disc_cols]

     device = resolve_device(args.device)
-    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device)
-    model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
+    cfg = {}
+    use_condition = False
+    cond_vocab_size = 0
+    if args.config and os.path.exists(args.config):
+        with open(args.config, "r", encoding="utf-8") as f:
+            cfg = json.load(f)
+        use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
+        if use_condition:
+            base = Path(cfg.get("data_glob", args.data_glob)).parent
+            pat = Path(cfg.get("data_glob", args.data_glob)).name
+            cond_vocab_size = len(sorted(base.glob(pat)))
+
+    model = HybridDiffusionModel(
+        cont_dim=len(cont_cols),
+        disc_vocab_sizes=vocab_sizes,
+        cond_vocab_size=cond_vocab_size if use_condition else 0,
+        cond_dim=int(cfg.get("cond_dim", 32)),
+        use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
+        eps_scale=float(cfg.get("eps_scale", 1.0)),
+    ).to(device)
+    if args.use_ema and os.path.exists(args.model_path.replace("model.pt", "model_ema.pt")):
+        ema_path = args.model_path.replace("model.pt", "model_ema.pt")
+        model.load_state_dict(torch.load(ema_path, map_location=device, weights_only=True))
+    else:
+        model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
     model.eval()

     betas = cosine_beta_schedule(args.timesteps).to(device)
@@ -108,9 +146,20 @@ def main():
     for i in range(len(disc_cols)):
         x_disc[:, :, i] = mask_tokens[i]

+    # condition id
+    cond = None
+    if use_condition:
+        if cond_vocab_size <= 0:
+            raise SystemExit("use_condition enabled but no files matched data_glob")
+        if args.condition_id < 0:
+            cond_id = torch.randint(0, cond_vocab_size, (args.batch_size,), device=device)
+        else:
+            cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
+        cond = cond_id
+
     for t in reversed(range(args.timesteps)):
         t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
-        eps_pred, logits = model(x_cont, x_disc, t_batch)
+        eps_pred, logits = model(x_cont, x_disc, t_batch, cond)

         a_t = alphas[t]
         a_bar_t = alphas_cumprod[t]
@@ -122,6 +171,8 @@ def main():
             x_cont = mean_x + torch.sqrt(betas[t]) * noise
         else:
             x_cont = mean_x
+        if args.clip_k > 0:
+            x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)

         for i, logit in enumerate(logits):
             if t == 0:
@@ -136,15 +187,22 @@ def main():
             )
             x_disc[:, :, i][mask] = sampled[mask]

     # move to CPU for export
     x_cont = x_cont.cpu()
     x_disc = x_disc.cpu()

+    # clip in normalized space to avoid extreme blow-up
+    if args.clip_k > 0:
+        x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
+
     mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype)
     std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype)
     x_cont = x_cont * std_vec + mean_vec

-    header = read_header(args.data_path)
+    header = read_header(data_path)
     out_cols = [c for c in header if c != time_col or args.include_time]
+    if args.include_condition and use_condition:
+        out_cols = ["__cond_file_id"] + out_cols

     os.makedirs(os.path.dirname(args.out), exist_ok=True)
     with open(args.out, "w", newline="", encoding="utf-8") as f:
@@ -155,6 +213,8 @@ def main():
         for b in range(args.batch_size):
             for t in range(args.seq_len):
                 row = {}
+                if args.include_condition and use_condition:
+                    row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1"
                 if args.include_time and time_col in header:
                     row[time_col] = str(row_index)
                 for i, c in enumerate(cont_cols):
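For orientation, `mean_x` above (its computation is elided by this hunk) is presumably the standard DDPM posterior mean computed from the predicted noise; a sketch under that assumption:

```python
import torch

def ddpm_reverse_mean(x_t: torch.Tensor, eps_pred: torch.Tensor,
                      a_t: torch.Tensor, a_bar_t: torch.Tensor) -> torch.Tensor:
    # mu(x_t, t) = (x_t - (1 - a_t) / sqrt(1 - a_bar_t) * eps_pred) / sqrt(a_t)
    # with a_t = alphas[t] and a_bar_t = alphas_cumprod[t], as in the loop above.
    return (x_t - (1.0 - a_t) / torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_t)
```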
@@ -36,9 +36,10 @@ def q_sample_discrete(
     mask_tokens: torch.Tensor,
     max_t: int,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Randomly mask discrete tokens with a linear schedule over t."""
+    """Randomly mask discrete tokens with a cosine schedule over t."""
     bsz = x0.size(0)
-    p = t.float() / float(max_t)
+    # cosine schedule: p(0)=0, p(max_t)=1
+    p = 0.5 * (1.0 - torch.cos(math.pi * t.float() / float(max_t)))
     p = p.view(bsz, 1, 1)
     mask = torch.rand_like(x0.float()) < p
     x_masked = x0.clone()
@@ -69,12 +70,24 @@ class HybridDiffusionModel(nn.Module):
         disc_vocab_sizes: List[int],
         time_dim: int = 64,
         hidden_dim: int = 256,
+        cond_vocab_size: int = 0,
+        cond_dim: int = 32,
+        use_tanh_eps: bool = False,
+        eps_scale: float = 1.0,
     ):
         super().__init__()
         self.cont_dim = cont_dim
         self.disc_vocab_sizes = disc_vocab_sizes

         self.time_embed = SinusoidalTimeEmbedding(time_dim)
+        self.use_tanh_eps = use_tanh_eps
+        self.eps_scale = eps_scale
+
+        self.cond_vocab_size = cond_vocab_size
+        self.cond_dim = cond_dim
+        self.cond_embed = None
+        if cond_vocab_size and cond_vocab_size > 0:
+            self.cond_embed = nn.Embedding(cond_vocab_size, cond_dim)

         self.disc_embeds = nn.ModuleList([
             nn.Embedding(vocab_size + 1, min(32, vocab_size * 2))
@@ -83,7 +96,8 @@ class HybridDiffusionModel(nn.Module):
         disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds)

         self.cont_proj = nn.Linear(cont_dim, cont_dim)
-        self.in_proj = nn.Linear(cont_dim + disc_embed_dim + time_dim, hidden_dim)
+        in_dim = cont_dim + disc_embed_dim + time_dim + (cond_dim if self.cond_embed is not None else 0)
+        self.in_proj = nn.Linear(in_dim, hidden_dim)
         self.backbone = nn.GRU(hidden_dim, hidden_dim, batch_first=True)

         self.cont_head = nn.Linear(hidden_dim, cont_dim)
@@ -92,7 +106,7 @@ class HybridDiffusionModel(nn.Module):
             for vocab_size in disc_vocab_sizes
         ])

-    def forward(self, x_cont: torch.Tensor, x_disc: torch.Tensor, t: torch.Tensor):
+    def forward(self, x_cont: torch.Tensor, x_disc: torch.Tensor, t: torch.Tensor, cond: torch.Tensor = None):
         """x_cont: (B,T,Cc), x_disc: (B,T,Cd) with integer tokens."""
         time_emb = self.time_embed(t)
         time_emb = time_emb.unsqueeze(1).expand(-1, x_cont.size(1), -1)
@@ -102,12 +116,23 @@ class HybridDiffusionModel(nn.Module):
             disc_embs.append(emb(x_disc[:, :, i]))
         disc_feat = torch.cat(disc_embs, dim=-1)

+        cond_feat = None
+        if self.cond_embed is not None:
+            if cond is None:
+                raise ValueError("cond is required when cond_vocab_size > 0")
+            cond_feat = self.cond_embed(cond).unsqueeze(1).expand(-1, x_cont.size(1), -1)
+
         cont_feat = self.cont_proj(x_cont)
-        feat = torch.cat([cont_feat, disc_feat, time_emb], dim=-1)
+        parts = [cont_feat, disc_feat, time_emb]
+        if cond_feat is not None:
+            parts.append(cond_feat)
+        feat = torch.cat(parts, dim=-1)
         feat = self.in_proj(feat)

         out, _ = self.backbone(feat)

         eps_pred = self.cont_head(out)
+        if self.use_tanh_eps:
+            eps_pred = torch.tanh(eps_pred) * self.eps_scale
         logits = [head(out) for head in self.disc_heads]
         return eps_pred, logits
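A minimal smoke-test sketch of the extended interface (sizes are illustrative; assumes `HybridDiffusionModel` is imported from the model module shown above):

```python
import torch
# from model import HybridDiffusionModel  # module name assumed

# Illustrative: 3 continuous channels, two discrete columns, 4 source files.
model = HybridDiffusionModel(
    cont_dim=3,
    disc_vocab_sizes=[5, 7],
    cond_vocab_size=4,      # e.g. number of train*.csv.gz files
    cond_dim=32,
    use_tanh_eps=True,
    eps_scale=1.0,
)
B, T = 2, 16
x_cont = torch.randn(B, T, 3)
x_disc = torch.randint(0, 5, (B, T, 2))   # token ids; index vocab_size is the mask token
t = torch.randint(0, 200, (B,))
cond = torch.randint(0, 4, (B,))           # one file id per sequence
eps_pred, logits = model(x_cont, x_disc, t, cond)
print(eps_pred.shape)                      # torch.Size([2, 16, 3]), bounded by tanh * eps_scale
```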
@@ -10,7 +10,7 @@ from platform_utils import safe_path, ensure_dir

 BASE_DIR = Path(__file__).resolve().parent
 REPO_DIR = BASE_DIR.parent.parent
-DATA_PATH = REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"
+DATA_GLOB = REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"
 SPLIT_PATH = BASE_DIR / "feature_split.json"
 OUT_STATS = BASE_DIR / "results" / "cont_stats.json"
 OUT_VOCAB = BASE_DIR / "results" / "disc_vocab.json"
@@ -22,8 +22,13 @@ def main(max_rows: Optional[int] = None):
     cont_cols = [c for c in split["continuous"] if c != time_col]
     disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

-    mean, std = compute_cont_stats(safe_path(DATA_PATH), cont_cols, max_rows=max_rows)
-    vocab = build_vocab(safe_path(DATA_PATH), disc_cols, max_rows=max_rows)
+    data_paths = sorted(Path(REPO_DIR / "dataset" / "hai" / "hai-21.03").glob("train*.csv.gz"))
+    if not data_paths:
+        raise SystemExit("no train files found under %s" % str(DATA_GLOB))
+    data_paths = [safe_path(p) for p in data_paths]
+
+    mean, std = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows)
+    vocab = build_vocab(data_paths, disc_cols, max_rows=max_rows)

     ensure_dir(OUT_STATS.parent)
     with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f:
@@ -5,6 +5,7 @@ import argparse
 import subprocess
 import sys
 from pathlib import Path
+import json

 from platform_utils import safe_path, is_windows

@@ -40,6 +41,13 @@ def parse_args():
 def main():
     args = parse_args()
     base_dir = Path(__file__).resolve().parent
+    config_path = Path(args.config)
+    with open(config_path, "r", encoding="utf-8") as f:
+        cfg = json.load(f)
+    timesteps = cfg.get("timesteps", 200)
+    seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64))
+    batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2))
+    clip_k = cfg.get("clip_k", 5.0)
     run([sys.executable, str(base_dir / "prepare_data.py")])
     run([sys.executable, str(base_dir / "train.py"), "--config", args.config, "--device", args.device])
     run(
@@ -49,6 +57,17 @@ def main():
             "--include-time",
             "--device",
             args.device,
+            "--config",
+            str(config_path),
+            "--timesteps",
+            str(timesteps),
+            "--seq-len",
+            str(seq_len),
+            "--batch-size",
+            str(batch_size),
+            "--clip-k",
+            str(clip_k),
+            "--use-ema",
         ]
     )
     run([sys.executable, str(base_dir / "evaluate_generated.py")])
@@ -17,6 +17,7 @@ BASE_DIR = Path(__file__).resolve().parent
 SPLIT_PATH = BASE_DIR / "feature_split.json"
 VOCAB_PATH = BASE_DIR / "results" / "disc_vocab.json"
 MODEL_PATH = BASE_DIR / "results" / "model.pt"
+CONFIG_PATH = BASE_DIR / "results" / "config_used.json"

 # Use the resolve_device function from platform_utils

@@ -25,6 +26,7 @@ DEVICE = resolve_device("auto")
 TIMESTEPS = 200
 SEQ_LEN = 64
 BATCH_SIZE = 2
+CLIP_K = 5.0


 def load_vocab():
@@ -33,6 +35,19 @@ def load_vocab():


 def main():
+    cfg = {}
+    if CONFIG_PATH.exists():
+        with open(str(CONFIG_PATH), "r", encoding="utf-8") as f:
+            cfg = json.load(f)
+    timesteps = int(cfg.get("timesteps", TIMESTEPS))
+    seq_len = int(cfg.get("sample_seq_len", cfg.get("seq_len", SEQ_LEN)))
+    batch_size = int(cfg.get("sample_batch_size", cfg.get("batch_size", BATCH_SIZE)))
+    clip_k = float(cfg.get("clip_k", CLIP_K))
+    use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
+    cond_dim = int(cfg.get("cond_dim", 32))
+    use_tanh_eps = bool(cfg.get("use_tanh_eps", False))
+    eps_scale = float(cfg.get("eps_scale", 1.0))
+
     split = load_split(str(SPLIT_PATH))
     time_col = split.get("time_column", "time")
     cont_cols = [c for c in split["continuous"] if c != time_col]
@@ -42,26 +57,46 @@ def main():
     vocab_sizes = [len(vocab[c]) for c in disc_cols]

     print("device", DEVICE)
-    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE)
+    cond_vocab_size = 0
+    if use_condition:
+        data_glob = cfg.get("data_glob")
+        if data_glob:
+            base = Path(data_glob).parent
+            pat = Path(data_glob).name
+            cond_vocab_size = len(sorted(base.glob(pat)))
+    model = HybridDiffusionModel(
+        cont_dim=len(cont_cols),
+        disc_vocab_sizes=vocab_sizes,
+        cond_vocab_size=cond_vocab_size,
+        cond_dim=cond_dim,
+        use_tanh_eps=use_tanh_eps,
+        eps_scale=eps_scale,
+    ).to(DEVICE)
     if MODEL_PATH.exists():
         model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
     model.eval()

-    betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)
+    betas = cosine_beta_schedule(timesteps).to(DEVICE)
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, dim=0)

-    x_cont = torch.randn(BATCH_SIZE, SEQ_LEN, len(cont_cols), device=DEVICE)
-    x_disc = torch.full((BATCH_SIZE, SEQ_LEN, len(disc_cols)), 0, device=DEVICE, dtype=torch.long)
+    x_cont = torch.randn(batch_size, seq_len, len(cont_cols), device=DEVICE)
+    x_disc = torch.full((batch_size, seq_len, len(disc_cols)), 0, device=DEVICE, dtype=torch.long)
     mask_tokens = torch.tensor(vocab_sizes, device=DEVICE)

     # Initialize discrete with mask tokens
     for i in range(len(disc_cols)):
         x_disc[:, :, i] = mask_tokens[i]

-    for t in reversed(range(TIMESTEPS)):
-        t_batch = torch.full((BATCH_SIZE,), t, device=DEVICE, dtype=torch.long)
-        eps_pred, logits = model(x_cont, x_disc, t_batch)
+    cond = None
+    if use_condition:
+        if cond_vocab_size <= 0:
+            raise SystemExit("use_condition enabled but no files matched data_glob")
+        cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long)
+
+    for t in reversed(range(timesteps)):
+        t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long)
+        eps_pred, logits = model(x_cont, x_disc, t_batch, cond)

         # Continuous reverse step (DDPM): x_{t-1} mean
         a_t = alphas[t]
@@ -74,6 +109,8 @@ def main():
             x_cont = mean + torch.sqrt(betas[t]) * noise
         else:
             x_cont = mean
+        if clip_k > 0:
+            x_cont = torch.clamp(x_cont, -clip_k, clip_k)

         # Discrete: fill masked positions by sampling logits
         for i, logit in enumerate(logits):
@@ -26,6 +26,7 @@ REPO_DIR = BASE_DIR.parent.parent

 DEFAULTS = {
     "data_path": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz",
+    "data_glob": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz",
     "split_path": BASE_DIR / "feature_split.json",
     "stats_path": BASE_DIR / "results" / "cont_stats.json",
     "vocab_path": BASE_DIR / "results" / "disc_vocab.json",
@@ -41,6 +42,15 @@ DEFAULTS = {
     "seed": 1337,
     "log_every": 10,
     "ckpt_every": 50,
+    "ema_decay": 0.999,
+    "use_ema": True,
+    "clip_k": 5.0,
+    "grad_clip": 1.0,
+    "use_condition": True,
+    "condition_type": "file_id",
+    "cond_dim": 32,
+    "use_tanh_eps": True,
+    "eps_scale": 1.0,
 }


@@ -69,7 +79,7 @@ def parse_args():


 def resolve_config_paths(config, base_dir: Path):
-    keys = ["data_path", "split_path", "stats_path", "vocab_path", "out_dir"]
+    keys = ["data_path", "data_glob", "split_path", "stats_path", "vocab_path", "out_dir"]
     for key in keys:
         if key in config:
             # If the value is a string, convert it to a Path object
@@ -85,6 +95,26 @@ def resolve_config_paths(config, base_dir: Path):
     return config


+class EMA:
+    def __init__(self, model, decay: float):
+        self.decay = decay
+        self.shadow = {}
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                self.shadow[name] = param.detach().clone()
+
+    def update(self, model):
+        with torch.no_grad():
+            for name, param in model.named_parameters():
+                if not param.requires_grad:
+                    continue
+                old = self.shadow[name]
+                self.shadow[name] = old * self.decay + param.detach() * (1.0 - self.decay)
+
+    def state_dict(self):
+        return self.shadow
+
+
 def main():
     args = parse_args()
     config = dict(DEFAULTS)
@@ -113,25 +143,47 @@ def main():
     vocab = load_json(config["vocab_path"])["vocab"]
     vocab_sizes = [len(vocab[c]) for c in disc_cols]

+    data_paths = None
+    if "data_glob" in config and config["data_glob"]:
+        data_paths = sorted(Path(config["data_glob"]).parent.glob(Path(config["data_glob"]).name))
+        if data_paths:
+            data_paths = [safe_path(p) for p in data_paths]
+    if not data_paths:
+        data_paths = [safe_path(config["data_path"])]
+
+    use_condition = bool(config.get("use_condition")) and config.get("condition_type") == "file_id"
+    cond_vocab_size = len(data_paths) if use_condition else 0
+
     device = resolve_device(str(config["device"]))
     print("device", device)
-    model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device)
+    model = HybridDiffusionModel(
+        cont_dim=len(cont_cols),
+        disc_vocab_sizes=vocab_sizes,
+        cond_vocab_size=cond_vocab_size,
+        cond_dim=int(config.get("cond_dim", 32)),
+        use_tanh_eps=bool(config.get("use_tanh_eps", False)),
+        eps_scale=float(config.get("eps_scale", 1.0)),
+    ).to(device)
     opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
+    ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None

     betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, dim=0)

     os.makedirs(config["out_dir"], exist_ok=True)
-    log_path = os.path.join(config["out_dir"], "train_log.csv")
+    out_dir = safe_path(config["out_dir"])
+    log_path = os.path.join(out_dir, "train_log.csv")
     with open(log_path, "w", encoding="utf-8") as f:
         f.write("epoch,step,loss,loss_cont,loss_disc\n")
+    with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f:
+        json.dump(config, f, indent=2)

     total_step = 0
     for epoch in range(int(config["epochs"])):
-        for step, (x_cont, x_disc) in enumerate(
+        for step, batch in enumerate(
             windowed_batches(
-                config["data_path"],
+                data_paths,
                 cont_cols,
                 disc_cols,
                 vocab,
@@ -140,8 +192,15 @@ def main():
                 batch_size=int(config["batch_size"]),
                 seq_len=int(config["seq_len"]),
                 max_batches=int(config["max_batches"]),
+                return_file_id=use_condition,
             )
         ):
+            if use_condition:
+                x_cont, x_disc, cond = batch
+                cond = cond.to(device)
+            else:
+                x_cont, x_disc = batch
+                cond = None
             x_cont = x_cont.to(device)
             x_disc = x_disc.to(device)

@@ -153,21 +212,29 @@ def main():
             mask_tokens = torch.tensor(vocab_sizes, device=device)
             x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"]))

-            eps_pred, logits = model(x_cont_t, x_disc_t, t)
+            eps_pred, logits = model(x_cont_t, x_disc_t, t, cond)

             loss_cont = F.mse_loss(eps_pred, noise)
             loss_disc = 0.0
+            loss_disc_count = 0
             for i, logit in enumerate(logits):
                 if mask[:, :, i].any():
                     loss_disc = loss_disc + F.cross_entropy(
                         logit[mask[:, :, i]], x_disc[:, :, i][mask[:, :, i]]
                     )
+                    loss_disc_count += 1
+            if loss_disc_count > 0:
+                loss_disc = loss_disc / loss_disc_count

             lam = float(config["lambda"])
             loss = lam * loss_cont + (1 - lam) * loss_disc
             opt.zero_grad()
             loss.backward()
+            if float(config.get("grad_clip", 0.0)) > 0:
+                torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
             opt.step()
+            if ema is not None:
+                ema.update(model)

             if step % int(config["log_every"]) == 0:
                 print("epoch", epoch, "step", step, "loss", float(loss))
@@ -185,9 +252,13 @@ def main():
                     "config": config,
                     "step": total_step,
                 }
-                torch.save(ckpt, os.path.join(config["out_dir"], "model_ckpt.pt"))
+                if ema is not None:
+                    ckpt["ema"] = ema.state_dict()
+                torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt"))

-    torch.save(model.state_dict(), os.path.join(config["out_dir"], "model.pt"))
+    torch.save(model.state_dict(), os.path.join(out_dir, "model.pt"))
+    if ema is not None:
+        torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt"))


 if __name__ == "__main__":
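To close the loop on EMA: `ema.update(model)` runs after every optimizer step, the shadow weights land in `model_ema.pt`, and `sample.py --use-ema` loads that file for sampling. A minimal toy sketch, assuming the `EMA` class from the diff above:

```python
# Toy illustration of the EMA flow wired above (nn.Linear is a stand-in model).
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
ema = EMA(model, decay=0.999)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(3):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.update(model)   # shadow <- decay * shadow + (1 - decay) * param

torch.save(ema.state_dict(), "model_ema.pt")  # later picked up by sample.py --use-ema
```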