diff --git a/example/README.md b/example/README.md index 57f60b4..3199277 100644 --- a/example/README.md +++ b/example/README.md @@ -1,7 +1,7 @@ # Example: HAI 21.03 Feature Split This folder contains a small, reproducible example that inspects the HAI 21.03 -CSV (train1) and produces a continuous/discrete split using a simple heuristic. +CSV (all train*.csv.gz files) and produces a continuous/discrete split using a simple heuristic. ## Files - analyze_hai21_03.py: reads a sample of the data and writes results. @@ -60,6 +60,12 @@ python example/run_pipeline.py --device auto - Set `device` in `example/config.json` to `auto` or `cuda` when moving to a GPU machine. - Attack label columns (`attack*`) are excluded from training and generation. - `time` column is always excluded from training and generation (optional for export only). +- EMA weights are saved as `model_ema.pt` and used by the pipeline for sampling. +- Gradients are clipped by default (`grad_clip` in `config.json`) to stabilize training. +- Discrete masking uses a cosine schedule for smoother corruption. +- Continuous sampling is clipped in normalized space each step for stability. +- Optional conditioning by file id (`train*.csv.gz`) is enabled by default for multi-file training. +- Continuous head can be bounded with `tanh` via `use_tanh_eps` in config. - The script only samples the first 5000 rows to stay fast. - `prepare_data.py` runs without PyTorch, but `train.py` and `sample.py` require it. - `train.py` and `sample.py` auto-select GPU if available; otherwise they fall back to CPU. 
diff --git a/example/config.json b/example/config.json index 38463d2..317301f 100644 --- a/example/config.json +++ b/example/config.json @@ -1,5 +1,6 @@ { "data_path": "../../dataset/hai/hai-21.03/train1.csv.gz", + "data_glob": "../../dataset/hai/hai-21.03/train*.csv.gz", "split_path": "./feature_split.json", "stats_path": "./results/cont_stats.json", "vocab_path": "./results/disc_vocab.json", @@ -14,5 +15,16 @@ "lr": 0.0005, "seed": 1337, "log_every": 10, - "ckpt_every": 50 + "ckpt_every": 50, + "ema_decay": 0.999, + "use_ema": true, + "clip_k": 5.0, + "grad_clip": 1.0, + "use_condition": true, + "condition_type": "file_id", + "cond_dim": 32, + "use_tanh_eps": true, + "eps_scale": 1.0, + "sample_batch_size": 8, + "sample_seq_len": 128 } diff --git a/example/data_utils.py b/example/data_utils.py index 685a9e4..333ee45 100755 --- a/example/data_utils.py +++ b/example/data_utils.py @@ -4,7 +4,7 @@ import csv import gzip import json -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple, Union @@ -13,15 +13,18 @@ def load_split(path: str) -> Dict[str, List[str]]: return json.load(f) -def iter_rows(path: str) -> Iterable[Dict[str, str]]: - with gzip.open(path, "rt", newline="") as f: - reader = csv.DictReader(f) - for row in reader: - yield row +def iter_rows(path_or_paths: Union[str, List[str]]) -> Iterable[Dict[str, str]]: + paths = [path_or_paths] if isinstance(path_or_paths, str) else list(path_or_paths) + for path in paths: + opener = gzip.open if str(path).endswith(".gz") else open + with opener(path, "rt", newline="") as f: + reader = csv.DictReader(f) + for row in reader: + yield row def compute_cont_stats( - path: str, + path: Union[str, List[str]], cont_cols: List[str], max_rows: Optional[int] = None, ) -> Tuple[Dict[str, float], Dict[str, float]]: @@ -52,7 +55,7 @@ def compute_cont_stats( def build_vocab( - path: str, + path: Union[str, List[str]], disc_cols: List[str], max_rows: Optional[int] = 
None, ) -> Dict[str, Dict[str, int]]: @@ -80,7 +83,7 @@ def normalize_cont(x, cont_cols: List[str], mean: Dict[str, float], std: Dict[st def windowed_batches( - path: str, + path: Union[str, List[str]], cont_cols: List[str], disc_cols: List[str], vocab: Dict[str, Dict[str, int]], @@ -89,10 +92,12 @@ def windowed_batches( batch_size: int, seq_len: int, max_batches: Optional[int] = None, + return_file_id: bool = False, ): import torch batch_cont = [] batch_disc = [] + batch_file = [] seq_cont = [] seq_disc = [] @@ -105,22 +110,34 @@ def windowed_batches( seq_disc = [] batches_yielded = 0 - for row in iter_rows(path): - cont_row = [float(row[c]) for c in cont_cols] - disc_row = [vocab[c].get(row[c], vocab[c][""]) for c in disc_cols] - seq_cont.append(cont_row) - seq_disc.append(disc_row) - if len(seq_cont) == seq_len: - flush_seq() - if len(batch_cont) == batch_size: - x_cont = torch.tensor(batch_cont, dtype=torch.float32) - x_disc = torch.tensor(batch_disc, dtype=torch.long) - x_cont = normalize_cont(x_cont, cont_cols, mean, std) - yield x_cont, x_disc - batch_cont = [] - batch_disc = [] - batches_yielded += 1 - if max_batches is not None and batches_yielded >= max_batches: - return + paths = [path] if isinstance(path, str) else list(path) + for file_id, p in enumerate(paths): + for row in iter_rows(p): + cont_row = [float(row[c]) for c in cont_cols] + disc_row = [vocab[c].get(row[c], vocab[c][""]) for c in disc_cols] + seq_cont.append(cont_row) + seq_disc.append(disc_row) + if len(seq_cont) == seq_len: + flush_seq() + if return_file_id: + batch_file.append(file_id) + if len(batch_cont) == batch_size: + x_cont = torch.tensor(batch_cont, dtype=torch.float32) + x_disc = torch.tensor(batch_disc, dtype=torch.long) + x_cont = normalize_cont(x_cont, cont_cols, mean, std) + if return_file_id: + x_file = torch.tensor(batch_file, dtype=torch.long) + yield x_cont, x_disc, x_file + else: + yield x_cont, x_disc + batch_cont = [] + batch_disc = [] + batch_file = [] + 
batches_yielded += 1 + if max_batches is not None and batches_yielded >= max_batches: + return + # drop partial sequence at file boundary + seq_cont = [] + seq_disc = [] # Drop last partial batch for simplicity diff --git a/example/export_samples.py b/example/export_samples.py index 47bd1d7..8e05e92 100644 --- a/example/export_samples.py +++ b/example/export_samples.py @@ -54,6 +54,7 @@ def parse_args(): base_dir = Path(__file__).resolve().parent repo_dir = base_dir.parent.parent parser.add_argument("--data-path", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz")) + parser.add_argument("--data-glob", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz")) parser.add_argument("--split-path", default=str(base_dir / "feature_split.json")) parser.add_argument("--stats-path", default=str(base_dir / "results" / "cont_stats.json")) parser.add_argument("--vocab-path", default=str(base_dir / "results" / "disc_vocab.json")) @@ -64,6 +65,11 @@ def parse_args(): parser.add_argument("--batch-size", type=int, default=2) parser.add_argument("--device", default="auto", help="cpu, cuda, or auto") parser.add_argument("--include-time", action="store_true", help="Include time column as a simple index") + parser.add_argument("--clip-k", type=float, default=5.0, help="Clip continuous values to mean±k*std") + parser.add_argument("--use-ema", action="store_true", help="Use EMA weights if available") + parser.add_argument("--config", default=None, help="Optional config_used.json to infer conditioning") + parser.add_argument("--condition-id", type=int, default=-1, help="Condition file id (0..N-1), -1=random") + parser.add_argument("--include-condition", action="store_true", help="Include condition id column in CSV") return parser.parse_args() @@ -75,6 +81,15 @@ def main(): if not os.path.exists(args.model_path): raise SystemExit("missing model file: %s" % args.model_path) + # resolve header source + data_path = args.data_path + if args.data_glob: + 
base = Path(args.data_glob).parent + pat = Path(args.data_glob).name + matches = sorted(base.glob(pat)) + if matches: + data_path = str(matches[0]) + split = load_split(args.split_path) time_col = split.get("time_column", "time") cont_cols = [c for c in split["continuous"] if c != time_col] @@ -89,8 +104,31 @@ def main(): vocab_sizes = [len(vocab[c]) for c in disc_cols] device = resolve_device(args.device) - model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device) - model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True)) + cfg = {} + use_condition = False + cond_vocab_size = 0 + if args.config and os.path.exists(args.config): + with open(args.config, "r", encoding="utf-8") as f: + cfg = json.load(f) + use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id" + if use_condition: + base = Path(cfg.get("data_glob", args.data_glob)).parent + pat = Path(cfg.get("data_glob", args.data_glob)).name + cond_vocab_size = len(sorted(base.glob(pat))) + + model = HybridDiffusionModel( + cont_dim=len(cont_cols), + disc_vocab_sizes=vocab_sizes, + cond_vocab_size=cond_vocab_size if use_condition else 0, + cond_dim=int(cfg.get("cond_dim", 32)), + use_tanh_eps=bool(cfg.get("use_tanh_eps", False)), + eps_scale=float(cfg.get("eps_scale", 1.0)), + ).to(device) + if args.use_ema and os.path.exists(args.model_path.replace("model.pt", "model_ema.pt")): + ema_path = args.model_path.replace("model.pt", "model_ema.pt") + model.load_state_dict(torch.load(ema_path, map_location=device, weights_only=True)) + else: + model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True)) model.eval() betas = cosine_beta_schedule(args.timesteps).to(device) @@ -108,9 +146,20 @@ def main(): for i in range(len(disc_cols)): x_disc[:, :, i] = mask_tokens[i] + # condition id + cond = None + if use_condition: + if cond_vocab_size <= 0: + raise SystemExit("use_condition enabled but 
no files matched data_glob") + if args.condition_id < 0: + cond_id = torch.randint(0, cond_vocab_size, (args.batch_size,), device=device) + else: + cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long) + cond = cond_id + for t in reversed(range(args.timesteps)): t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long) - eps_pred, logits = model(x_cont, x_disc, t_batch) + eps_pred, logits = model(x_cont, x_disc, t_batch, cond) a_t = alphas[t] a_bar_t = alphas_cumprod[t] @@ -122,6 +171,8 @@ def main(): x_cont = mean_x + torch.sqrt(betas[t]) * noise else: x_cont = mean_x + if args.clip_k > 0: + x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k) for i, logit in enumerate(logits): if t == 0: @@ -136,15 +187,22 @@ def main(): ) x_disc[:, :, i][mask] = sampled[mask] + # move to CPU for export x_cont = x_cont.cpu() x_disc = x_disc.cpu() + # clip in normalized space to avoid extreme blow-up + if args.clip_k > 0: + x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k) + mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=x_cont.dtype) std_vec = torch.tensor([std[c] for c in cont_cols], dtype=x_cont.dtype) x_cont = x_cont * std_vec + mean_vec - header = read_header(args.data_path) + header = read_header(data_path) out_cols = [c for c in header if c != time_col or args.include_time] + if args.include_condition and use_condition: + out_cols = ["__cond_file_id"] + out_cols os.makedirs(os.path.dirname(args.out), exist_ok=True) with open(args.out, "w", newline="", encoding="utf-8") as f: @@ -155,6 +213,8 @@ def main(): for b in range(args.batch_size): for t in range(args.seq_len): row = {} + if args.include_condition and use_condition: + row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1" if args.include_time and time_col in header: row[time_col] = str(row_index) for i, c in enumerate(cont_cols): diff --git a/example/hybrid_diffusion.py b/example/hybrid_diffusion.py index 
df4176e..506a7ad 100755 --- a/example/hybrid_diffusion.py +++ b/example/hybrid_diffusion.py @@ -36,9 +36,10 @@ def q_sample_discrete( mask_tokens: torch.Tensor, max_t: int, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Randomly mask discrete tokens with a linear schedule over t.""" + """Randomly mask discrete tokens with a cosine schedule over t.""" bsz = x0.size(0) - p = t.float() / float(max_t) + # cosine schedule: p(0)=0, p(max_t)=1 + p = 0.5 * (1.0 - torch.cos(math.pi * t.float() / float(max_t))) p = p.view(bsz, 1, 1) mask = torch.rand_like(x0.float()) < p x_masked = x0.clone() @@ -69,12 +70,24 @@ class HybridDiffusionModel(nn.Module): disc_vocab_sizes: List[int], time_dim: int = 64, hidden_dim: int = 256, + cond_vocab_size: int = 0, + cond_dim: int = 32, + use_tanh_eps: bool = False, + eps_scale: float = 1.0, ): super().__init__() self.cont_dim = cont_dim self.disc_vocab_sizes = disc_vocab_sizes self.time_embed = SinusoidalTimeEmbedding(time_dim) + self.use_tanh_eps = use_tanh_eps + self.eps_scale = eps_scale + + self.cond_vocab_size = cond_vocab_size + self.cond_dim = cond_dim + self.cond_embed = None + if cond_vocab_size and cond_vocab_size > 0: + self.cond_embed = nn.Embedding(cond_vocab_size, cond_dim) self.disc_embeds = nn.ModuleList([ nn.Embedding(vocab_size + 1, min(32, vocab_size * 2)) @@ -83,7 +96,8 @@ class HybridDiffusionModel(nn.Module): disc_embed_dim = sum(e.embedding_dim for e in self.disc_embeds) self.cont_proj = nn.Linear(cont_dim, cont_dim) - self.in_proj = nn.Linear(cont_dim + disc_embed_dim + time_dim, hidden_dim) + in_dim = cont_dim + disc_embed_dim + time_dim + (cond_dim if self.cond_embed is not None else 0) + self.in_proj = nn.Linear(in_dim, hidden_dim) self.backbone = nn.GRU(hidden_dim, hidden_dim, batch_first=True) self.cont_head = nn.Linear(hidden_dim, cont_dim) @@ -92,7 +106,7 @@ class HybridDiffusionModel(nn.Module): for vocab_size in disc_vocab_sizes ]) - def forward(self, x_cont: torch.Tensor, x_disc: torch.Tensor, t: 
torch.Tensor): + def forward(self, x_cont: torch.Tensor, x_disc: torch.Tensor, t: torch.Tensor, cond: torch.Tensor = None): """x_cont: (B,T,Cc), x_disc: (B,T,Cd) with integer tokens.""" time_emb = self.time_embed(t) time_emb = time_emb.unsqueeze(1).expand(-1, x_cont.size(1), -1) @@ -102,12 +116,23 @@ class HybridDiffusionModel(nn.Module): disc_embs.append(emb(x_disc[:, :, i])) disc_feat = torch.cat(disc_embs, dim=-1) + cond_feat = None + if self.cond_embed is not None: + if cond is None: + raise ValueError("cond is required when cond_vocab_size > 0") + cond_feat = self.cond_embed(cond).unsqueeze(1).expand(-1, x_cont.size(1), -1) + cont_feat = self.cont_proj(x_cont) - feat = torch.cat([cont_feat, disc_feat, time_emb], dim=-1) + parts = [cont_feat, disc_feat, time_emb] + if cond_feat is not None: + parts.append(cond_feat) + feat = torch.cat(parts, dim=-1) feat = self.in_proj(feat) out, _ = self.backbone(feat) eps_pred = self.cont_head(out) + if self.use_tanh_eps: + eps_pred = torch.tanh(eps_pred) * self.eps_scale logits = [head(out) for head in self.disc_heads] return eps_pred, logits diff --git a/example/prepare_data.py b/example/prepare_data.py index 6eb8e42..d28c227 100755 --- a/example/prepare_data.py +++ b/example/prepare_data.py @@ -10,7 +10,7 @@ from platform_utils import safe_path, ensure_dir BASE_DIR = Path(__file__).resolve().parent REPO_DIR = BASE_DIR.parent.parent -DATA_PATH = REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz" +DATA_GLOB = REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz" SPLIT_PATH = BASE_DIR / "feature_split.json" OUT_STATS = BASE_DIR / "results" / "cont_stats.json" OUT_VOCAB = BASE_DIR / "results" / "disc_vocab.json" @@ -22,8 +22,13 @@ def main(max_rows: Optional[int] = None): cont_cols = [c for c in split["continuous"] if c != time_col] disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col] - mean, std = compute_cont_stats(safe_path(DATA_PATH), cont_cols, max_rows=max_rows) - 
vocab = build_vocab(safe_path(DATA_PATH), disc_cols, max_rows=max_rows) + data_paths = sorted(Path(REPO_DIR / "dataset" / "hai" / "hai-21.03").glob("train*.csv.gz")) + if not data_paths: + raise SystemExit("no train files found under %s" % str(DATA_GLOB)) + data_paths = [safe_path(p) for p in data_paths] + + mean, std = compute_cont_stats(data_paths, cont_cols, max_rows=max_rows) + vocab = build_vocab(data_paths, disc_cols, max_rows=max_rows) ensure_dir(OUT_STATS.parent) with open(safe_path(OUT_STATS), "w", encoding="utf-8") as f: diff --git a/example/run_pipeline.py b/example/run_pipeline.py index e9e9486..0f1fd69 100644 --- a/example/run_pipeline.py +++ b/example/run_pipeline.py @@ -5,6 +5,7 @@ import argparse import subprocess import sys from pathlib import Path +import json from platform_utils import safe_path, is_windows @@ -40,6 +41,13 @@ def parse_args(): def main(): args = parse_args() base_dir = Path(__file__).resolve().parent + config_path = Path(args.config) + with open(config_path, "r", encoding="utf-8") as f: + cfg = json.load(f) + timesteps = cfg.get("timesteps", 200) + seq_len = cfg.get("sample_seq_len", cfg.get("seq_len", 64)) + batch_size = cfg.get("sample_batch_size", cfg.get("batch_size", 2)) + clip_k = cfg.get("clip_k", 5.0) run([sys.executable, str(base_dir / "prepare_data.py")]) run([sys.executable, str(base_dir / "train.py"), "--config", args.config, "--device", args.device]) run( @@ -49,6 +57,17 @@ def main(): "--include-time", "--device", args.device, + "--config", + str(config_path), + "--timesteps", + str(timesteps), + "--seq-len", + str(seq_len), + "--batch-size", + str(batch_size), + "--clip-k", + str(clip_k), + "--use-ema", ] ) run([sys.executable, str(base_dir / "evaluate_generated.py")]) diff --git a/example/sample.py b/example/sample.py index 071d002..14a5863 100755 --- a/example/sample.py +++ b/example/sample.py @@ -17,6 +17,7 @@ BASE_DIR = Path(__file__).resolve().parent SPLIT_PATH = BASE_DIR / "feature_split.json" VOCAB_PATH = 
BASE_DIR / "results" / "disc_vocab.json" MODEL_PATH = BASE_DIR / "results" / "model.pt" +CONFIG_PATH = BASE_DIR / "results" / "config_used.json" # 使用 platform_utils 中的 resolve_device 函数 @@ -25,6 +26,7 @@ DEVICE = resolve_device("auto") TIMESTEPS = 200 SEQ_LEN = 64 BATCH_SIZE = 2 +CLIP_K = 5.0 def load_vocab(): @@ -33,6 +35,19 @@ def load_vocab(): def main(): + cfg = {} + if CONFIG_PATH.exists(): + with open(str(CONFIG_PATH), "r", encoding="utf-8") as f: + cfg = json.load(f) + timesteps = int(cfg.get("timesteps", TIMESTEPS)) + seq_len = int(cfg.get("sample_seq_len", cfg.get("seq_len", SEQ_LEN))) + batch_size = int(cfg.get("sample_batch_size", cfg.get("batch_size", BATCH_SIZE))) + clip_k = float(cfg.get("clip_k", CLIP_K)) + use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id" + cond_dim = int(cfg.get("cond_dim", 32)) + use_tanh_eps = bool(cfg.get("use_tanh_eps", False)) + eps_scale = float(cfg.get("eps_scale", 1.0)) + split = load_split(str(SPLIT_PATH)) time_col = split.get("time_column", "time") cont_cols = [c for c in split["continuous"] if c != time_col] @@ -42,26 +57,46 @@ def main(): vocab_sizes = [len(vocab[c]) for c in disc_cols] print("device", DEVICE) - model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(DEVICE) + cond_vocab_size = 0 + if use_condition: + data_glob = cfg.get("data_glob") + if data_glob: + base = Path(data_glob).parent + pat = Path(data_glob).name + cond_vocab_size = len(sorted(base.glob(pat))) + model = HybridDiffusionModel( + cont_dim=len(cont_cols), + disc_vocab_sizes=vocab_sizes, + cond_vocab_size=cond_vocab_size, + cond_dim=cond_dim, + use_tanh_eps=use_tanh_eps, + eps_scale=eps_scale, + ).to(DEVICE) if MODEL_PATH.exists(): model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True)) model.eval() - betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE) + betas = cosine_beta_schedule(timesteps).to(DEVICE) alphas = 1.0 - betas alphas_cumprod = 
torch.cumprod(alphas, dim=0) - x_cont = torch.randn(BATCH_SIZE, SEQ_LEN, len(cont_cols), device=DEVICE) - x_disc = torch.full((BATCH_SIZE, SEQ_LEN, len(disc_cols)), 0, device=DEVICE, dtype=torch.long) + x_cont = torch.randn(batch_size, seq_len, len(cont_cols), device=DEVICE) + x_disc = torch.full((batch_size, seq_len, len(disc_cols)), 0, device=DEVICE, dtype=torch.long) mask_tokens = torch.tensor(vocab_sizes, device=DEVICE) # Initialize discrete with mask tokens for i in range(len(disc_cols)): x_disc[:, :, i] = mask_tokens[i] - for t in reversed(range(TIMESTEPS)): - t_batch = torch.full((BATCH_SIZE,), t, device=DEVICE, dtype=torch.long) - eps_pred, logits = model(x_cont, x_disc, t_batch) + cond = None + if use_condition: + if cond_vocab_size <= 0: + raise SystemExit("use_condition enabled but no files matched data_glob") + cond = torch.randint(0, cond_vocab_size, (batch_size,), device=DEVICE, dtype=torch.long) + + for t in reversed(range(timesteps)): + t_batch = torch.full((batch_size,), t, device=DEVICE, dtype=torch.long) + eps_pred, logits = model(x_cont, x_disc, t_batch, cond) # Continuous reverse step (DDPM): x_{t-1} mean a_t = alphas[t] @@ -74,6 +109,8 @@ def main(): x_cont = mean + torch.sqrt(betas[t]) * noise else: x_cont = mean + if clip_k > 0: + x_cont = torch.clamp(x_cont, -clip_k, clip_k) # Discrete: fill masked positions by sampling logits for i, logit in enumerate(logits): diff --git a/example/train.py b/example/train.py index 2ed80f8..e265116 100755 --- a/example/train.py +++ b/example/train.py @@ -26,6 +26,7 @@ REPO_DIR = BASE_DIR.parent.parent DEFAULTS = { "data_path": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz", + "data_glob": REPO_DIR / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz", "split_path": BASE_DIR / "feature_split.json", "stats_path": BASE_DIR / "results" / "cont_stats.json", "vocab_path": BASE_DIR / "results" / "disc_vocab.json", @@ -41,6 +42,15 @@ DEFAULTS = { "seed": 1337, "log_every": 10, "ckpt_every": 50, + 
"ema_decay": 0.999, + "use_ema": True, + "clip_k": 5.0, + "grad_clip": 1.0, + "use_condition": True, + "condition_type": "file_id", + "cond_dim": 32, + "use_tanh_eps": True, + "eps_scale": 1.0, } @@ -69,7 +79,7 @@ def parse_args(): def resolve_config_paths(config, base_dir: Path): - keys = ["data_path", "split_path", "stats_path", "vocab_path", "out_dir"] + keys = ["data_path", "data_glob", "split_path", "stats_path", "vocab_path", "out_dir"] for key in keys: if key in config: # 如果值是字符串,转换为Path对象 @@ -85,6 +95,26 @@ def resolve_config_paths(config, base_dir: Path): return config +class EMA: + def __init__(self, model, decay: float): + self.decay = decay + self.shadow = {} + for name, param in model.named_parameters(): + if param.requires_grad: + self.shadow[name] = param.detach().clone() + + def update(self, model): + with torch.no_grad(): + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + old = self.shadow[name] + self.shadow[name] = old * self.decay + param.detach() * (1.0 - self.decay) + + def state_dict(self): + return self.shadow + + def main(): args = parse_args() config = dict(DEFAULTS) @@ -113,25 +143,47 @@ def main(): vocab = load_json(config["vocab_path"])["vocab"] vocab_sizes = [len(vocab[c]) for c in disc_cols] + data_paths = None + if "data_glob" in config and config["data_glob"]: + data_paths = sorted(Path(config["data_glob"]).parent.glob(Path(config["data_glob"]).name)) + if data_paths: + data_paths = [safe_path(p) for p in data_paths] + if not data_paths: + data_paths = [safe_path(config["data_path"])] + + use_condition = bool(config.get("use_condition")) and config.get("condition_type") == "file_id" + cond_vocab_size = len(data_paths) if use_condition else 0 + device = resolve_device(str(config["device"])) print("device", device) - model = HybridDiffusionModel(cont_dim=len(cont_cols), disc_vocab_sizes=vocab_sizes).to(device) + model = HybridDiffusionModel( + cont_dim=len(cont_cols), + 
disc_vocab_sizes=vocab_sizes, + cond_vocab_size=cond_vocab_size, + cond_dim=int(config.get("cond_dim", 32)), + use_tanh_eps=bool(config.get("use_tanh_eps", False)), + eps_scale=float(config.get("eps_scale", 1.0)), + ).to(device) opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"])) + ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None betas = cosine_beta_schedule(int(config["timesteps"])).to(device) alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) os.makedirs(config["out_dir"], exist_ok=True) - log_path = os.path.join(config["out_dir"], "train_log.csv") + out_dir = safe_path(config["out_dir"]) + log_path = os.path.join(out_dir, "train_log.csv") with open(log_path, "w", encoding="utf-8") as f: f.write("epoch,step,loss,loss_cont,loss_disc\n") + with open(os.path.join(out_dir, "config_used.json"), "w", encoding="utf-8") as f: + json.dump(config, f, indent=2, default=str) total_step = 0 for epoch in range(int(config["epochs"])): - for step, (x_cont, x_disc) in enumerate( + for step, batch in enumerate( windowed_batches( - config["data_path"], + data_paths, cont_cols, disc_cols, vocab, @@ -140,8 +192,15 @@ batch_size=int(config["batch_size"]), seq_len=int(config["seq_len"]), max_batches=int(config["max_batches"]), + return_file_id=use_condition, ) ): + if use_condition: + x_cont, x_disc, cond = batch + cond = cond.to(device) + else: + x_cont, x_disc = batch + cond = None x_cont = x_cont.to(device) x_disc = x_disc.to(device) @@ -153,21 +212,29 @@ mask_tokens = torch.tensor(vocab_sizes, device=device) x_disc_t, mask = q_sample_discrete(x_disc, t, mask_tokens, int(config["timesteps"])) - eps_pred, logits = model(x_cont_t, x_disc_t, t) + eps_pred, logits = model(x_cont_t, x_disc_t, t, cond) loss_cont = F.mse_loss(eps_pred, noise) loss_disc = 0.0 + loss_disc_count = 0 for i, logit in enumerate(logits): if mask[:, :, i].any(): loss_disc = loss_disc + F.cross_entropy( logit[mask[:, :, i]], x_disc[:, :, 
i][mask[:, :, i]] ) + loss_disc_count += 1 + if loss_disc_count > 0: + loss_disc = loss_disc / loss_disc_count lam = float(config["lambda"]) loss = lam * loss_cont + (1 - lam) * loss_disc opt.zero_grad() loss.backward() + if float(config.get("grad_clip", 0.0)) > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"])) opt.step() + if ema is not None: + ema.update(model) if step % int(config["log_every"]) == 0: print("epoch", epoch, "step", step, "loss", float(loss)) @@ -185,9 +252,13 @@ def main(): "config": config, "step": total_step, } - torch.save(ckpt, os.path.join(config["out_dir"], "model_ckpt.pt")) + if ema is not None: + ckpt["ema"] = ema.state_dict() + torch.save(ckpt, os.path.join(out_dir, "model_ckpt.pt")) - torch.save(model.state_dict(), os.path.join(config["out_dir"], "model.pt")) + torch.save(model.state_dict(), os.path.join(out_dir, "model.pt")) + if ema is not None: + torch.save(ema.state_dict(), os.path.join(out_dir, "model_ema.pt")) if __name__ == "__main__": diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..c287e31 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +torch +numpy +matplotlib