diff --git a/example/config_submission_full.json b/example/config_submission_full.json new file mode 100644 index 0000000..71d678c --- /dev/null +++ b/example/config_submission_full.json @@ -0,0 +1,81 @@ +{ + "data_path": "../../dataset/hai/hai-21.03/train1.csv.gz", + "data_glob": "../../dataset/hai/hai-21.03/train*.csv.gz", + "split_path": "./feature_split.json", + "stats_path": "./results/cont_stats.json", + "vocab_path": "./results/disc_vocab.json", + "out_dir": "./results", + "device": "auto", + "timesteps": 600, + "batch_size": 12, + "seq_len": 96, + "epochs": 10, + "max_batches": 4000, + "lambda": 0.7, + "lr": 0.0005, + "seed": 1337, + "log_every": 10, + "ckpt_every": 50, + "ema_decay": 0.999, + "use_ema": true, + "clip_k": 5.0, + "grad_clip": 1.0, + "use_condition": true, + "condition_type": "file_id", + "cond_dim": 32, + "use_tanh_eps": false, + "eps_scale": 1.0, + "model_time_dim": 128, + "model_hidden_dim": 512, + "model_num_layers": 2, + "model_dropout": 0.1, + "model_ff_mult": 2, + "model_pos_dim": 64, + "model_use_pos_embed": true, + "backbone_type": "transformer", + "transformer_num_layers": 3, + "transformer_nhead": 4, + "transformer_ff_dim": 512, + "transformer_dropout": 0.1, + "disc_mask_scale": 0.9, + "cont_loss_weighting": "inv_std", + "cont_loss_eps": 1e-6, + "cont_target": "x0", + "cont_clamp_x0": 5.0, + "use_quantile_transform": true, + "quantile_bins": 1001, + "cont_bound_mode": "none", + "cont_bound_strength": 2.0, + "cont_post_calibrate": true, + "cont_post_scale": {}, + "full_stats": true, + "type1_features": ["P1_B4002", "P2_MSD", "P4_HT_LD", "P1_B2004", "P1_B3004", "P1_B4022", "P1_B3005"], + "type2_features": ["P1_B4005"], + "type3_features": ["P1_PCV02Z", "P1_PCV01Z", "P1_PCV01D", "P1_FCV02Z", "P1_FCV03D", "P1_FCV03Z", "P1_LCV01D", "P1_LCV01Z"], + "type4_features": ["P1_PIT02", "P2_SIT02", "P1_FT03"], + "type5_features": ["P1_FT03Z", "P1_FT02Z"], + "type6_features": ["P4_HT_PO", "P2_24Vdc", "P2_HILout"], + "routing_type1_features": 
["P1_B4022"], + "routing_type5_features": [], + "shuffle_buffer": 256, + "use_temporal_stage1": true, + "temporal_backbone": "transformer", + "temporal_hidden_dim": 256, + "temporal_num_layers": 1, + "temporal_dropout": 0.0, + "temporal_pos_dim": 64, + "temporal_use_pos_embed": true, + "temporal_transformer_num_layers": 2, + "temporal_transformer_nhead": 4, + "temporal_transformer_ff_dim": 256, + "temporal_transformer_dropout": 0.1, + "temporal_epochs": 3, + "temporal_lr": 0.001, + "quantile_loss_weight": 0.2, + "quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95], + "snr_weighted_loss": true, + "snr_gamma": 1.0, + "residual_stat_weight": 0.05, + "sample_batch_size": 4, + "sample_seq_len": 96 +} diff --git a/example/export_samples_resume.py b/example/export_samples_resume.py new file mode 100644 index 0000000..690d6ba --- /dev/null +++ b/example/export_samples_resume.py @@ -0,0 +1,440 @@ +#!/usr/bin/env python3 +"""Sample from a trained hybrid diffusion model with routing-aware export fixes.""" + +from __future__ import annotations + +import argparse +import csv +import gzip +import json +import os +from pathlib import Path + +import torch +import torch.nn.functional as F + +from data_utils import inverse_quantile_transform, load_split, normalize_cont, quantile_calibrate_to_real +from export_samples import build_inverse_vocab, load_stats, load_torch_state, read_header +from hybrid_diffusion import ( + HybridDiffusionModel, + TemporalGRUGenerator, + TemporalTransformerGenerator, + cosine_beta_schedule, +) +from platform_utils import resolve_device, resolve_path +from submission_type_utils import ( + denormalize_cont_tensor, + resolve_routing_features, + resolve_taxonomy_features, +) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Sample and export HAI feature sequences with routing-aware fixes.") + base_dir = Path(__file__).resolve().parent + repo_dir = base_dir.parent.parent + parser.add_argument("--data-path", default=str(repo_dir / "dataset" / 
"hai" / "hai-21.03" / "train1.csv.gz")) + parser.add_argument("--data-glob", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz")) + parser.add_argument("--split-path", default=str(base_dir / "feature_split.json")) + parser.add_argument("--stats-path", default=str(base_dir / "results" / "cont_stats.json")) + parser.add_argument("--vocab-path", default=str(base_dir / "results" / "disc_vocab.json")) + parser.add_argument("--model-path", default=str(base_dir / "results" / "model.pt")) + parser.add_argument("--out", default=str(base_dir / "results" / "generated.csv")) + parser.add_argument("--timesteps", type=int, default=200) + parser.add_argument("--seq-len", type=int, default=64) + parser.add_argument("--batch-size", type=int, default=2) + parser.add_argument("--device", default="auto", help="cpu, cuda, or auto") + parser.add_argument("--include-time", action="store_true", help="Include time column as a simple index") + parser.add_argument("--clip-k", type=float, default=5.0, help="Clip continuous values to mean±k*std") + parser.add_argument("--use-ema", action="store_true", help="Use EMA weights if available") + parser.add_argument("--config", default=None, help="Optional config_used.json to infer conditioning") + parser.add_argument("--condition-id", type=int, default=-1, help="Condition file id (0..N-1), -1=random") + parser.add_argument("--include-condition", action="store_true", help="Include condition id column in CSV") + return parser.parse_args() + + +def main(): + args = parse_args() + base_dir = Path(__file__).resolve().parent + args.data_path = str(resolve_path(base_dir, args.data_path)) + args.data_glob = str(resolve_path(base_dir, args.data_glob)) if args.data_glob else "" + args.split_path = str(resolve_path(base_dir, args.split_path)) + args.stats_path = str(resolve_path(base_dir, args.stats_path)) + args.vocab_path = str(resolve_path(base_dir, args.vocab_path)) + args.model_path = str(resolve_path(base_dir, args.model_path)) + 
args.out = str(resolve_path(base_dir, args.out)) + + if not os.path.exists(args.model_path): + raise SystemExit(f"missing model file: {args.model_path}") + + data_path = args.data_path + if args.data_glob: + base = Path(args.data_glob).parent + pat = Path(args.data_glob).name + matches = sorted(base.glob(pat)) + if matches: + data_path = str(matches[0]) + + split = load_split(args.split_path) + time_col = split.get("time_column", "time") + cont_cols = [c for c in split["continuous"] if c != time_col] + disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col] + + stats = load_stats(args.stats_path) + mean = stats["mean"] + std = stats["std"] + vmin = stats.get("min", {}) + vmax = stats.get("max", {}) + int_like = stats.get("int_like", {}) + max_decimals = stats.get("max_decimals", {}) + transforms = stats.get("transform", {}) + quantile_probs = stats.get("quantile_probs") + quantile_values = stats.get("quantile_values") + quantile_raw_values = stats.get("quantile_raw_values") + + vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8")) + vocab = vocab_json["vocab"] + top_token = vocab_json.get("top_token", {}) + inv_vocab = build_inverse_vocab(vocab) + vocab_sizes = [len(vocab[c]) for c in disc_cols] + + device = resolve_device(args.device) + cfg = {} + use_condition = False + cond_vocab_size = 0 + if args.config: + args.config = str(resolve_path(base_dir, args.config)) + if args.config and os.path.exists(args.config): + with open(args.config, "r", encoding="utf-8") as f: + cfg = json.load(f) + use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id" + if use_condition: + cfg_base = Path(args.config).resolve().parent + cfg_glob = cfg.get("data_glob", args.data_glob) + cfg_glob = str(resolve_path(cfg_base, cfg_glob)) + base = Path(cfg_glob).parent + pat = Path(cfg_glob).name + cond_vocab_size = len(sorted(base.glob(pat))) + if cond_vocab_size <= 0: + raise SystemExit(f"use_condition 
enabled but no files matched data_glob: {cfg_glob}") + + cont_target = str(cfg.get("cont_target", "eps")) + cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0)) + use_quantile = bool(cfg.get("use_quantile_transform", False)) + cont_bound_mode = str(cfg.get("cont_bound_mode", "clamp")) + cont_bound_strength = float(cfg.get("cont_bound_strength", 1.0)) + cont_post_scale = cfg.get("cont_post_scale", {}) if isinstance(cfg.get("cont_post_scale", {}), dict) else {} + cont_post_calibrate = bool(cfg.get("cont_post_calibrate", False)) + + route_type1_cols = resolve_routing_features(cfg, cont_cols, "type1_features") + route_type5_cols = resolve_routing_features(cfg, cont_cols, "type5_features") + type4_cols = resolve_taxonomy_features(cfg, cont_cols, "type4_features") + model_cont_cols = [c for c in cont_cols if c not in route_type1_cols and c not in route_type5_cols] + + use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False)) + temporal_use_type1_cond = bool(cfg.get("temporal_use_type1_cond", False)) + temporal_focus_type4 = bool(cfg.get("temporal_focus_type4", False)) + temporal_exclude_type4 = bool(cfg.get("temporal_exclude_type4", False)) + temporal_backbone = str(cfg.get("temporal_backbone", "gru")) + temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256)) + temporal_num_layers = int(cfg.get("temporal_num_layers", 1)) + temporal_dropout = float(cfg.get("temporal_dropout", 0.0)) + temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64)) + temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True)) + temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2)) + temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4)) + temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512)) + temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1)) + backbone_type = str(cfg.get("backbone_type", "gru")) + transformer_num_layers = int(cfg.get("transformer_num_layers", 2)) + 
transformer_nhead = int(cfg.get("transformer_nhead", 4)) + transformer_ff_dim = int(cfg.get("transformer_ff_dim", 512)) + transformer_dropout = float(cfg.get("transformer_dropout", 0.1)) + + model = HybridDiffusionModel( + cont_dim=len(model_cont_cols), + disc_vocab_sizes=vocab_sizes, + time_dim=int(cfg.get("model_time_dim", 64)), + hidden_dim=int(cfg.get("model_hidden_dim", 256)), + num_layers=int(cfg.get("model_num_layers", 1)), + dropout=float(cfg.get("model_dropout", 0.0)), + ff_mult=int(cfg.get("model_ff_mult", 2)), + pos_dim=int(cfg.get("model_pos_dim", 64)), + use_pos_embed=bool(cfg.get("model_use_pos_embed", True)), + backbone_type=backbone_type, + transformer_num_layers=transformer_num_layers, + transformer_nhead=transformer_nhead, + transformer_ff_dim=transformer_ff_dim, + transformer_dropout=transformer_dropout, + cond_cont_dim=len(route_type1_cols), + cond_vocab_size=cond_vocab_size if use_condition else 0, + cond_dim=int(cfg.get("cond_dim", 32)), + use_tanh_eps=bool(cfg.get("use_tanh_eps", False)), + eps_scale=float(cfg.get("eps_scale", 1.0)), + ).to(device) + if args.use_ema and os.path.exists(args.model_path.replace("model.pt", "model_ema.pt")): + ema_path = args.model_path.replace("model.pt", "model_ema.pt") + model.load_state_dict(load_torch_state(ema_path, device)) + else: + model.load_state_dict(load_torch_state(args.model_path, device)) + model.eval() + + temporal_model = None + if use_temporal_stage1: + temporal_path = Path(args.model_path).with_name("temporal.pt") + if not temporal_path.exists(): + raise SystemExit(f"missing temporal model file: {temporal_path}") + temporal_state = load_torch_state(str(temporal_path), device) + temporal_cond_dim = len(route_type1_cols) if (temporal_use_type1_cond and route_type1_cols) else 0 + if isinstance(temporal_state, dict): + if "in_proj.weight" in temporal_state: + try: + temporal_cond_dim = max(0, int(temporal_state["in_proj.weight"].shape[1]) - len(model_cont_cols)) + except Exception: + pass + elif 
"gru.weight_ih_l0" in temporal_state: + try: + temporal_cond_dim = max(0, int(temporal_state["gru.weight_ih_l0"].shape[1]) - len(model_cont_cols)) + except Exception: + pass + if temporal_backbone == "transformer": + temporal_model = TemporalTransformerGenerator( + input_dim=len(model_cont_cols), + hidden_dim=temporal_hidden_dim, + num_layers=temporal_transformer_num_layers, + nhead=temporal_transformer_nhead, + ff_dim=temporal_transformer_ff_dim, + dropout=temporal_transformer_dropout, + pos_dim=temporal_pos_dim, + use_pos_embed=temporal_use_pos_embed, + cond_dim=temporal_cond_dim, + ).to(device) + else: + temporal_model = TemporalGRUGenerator( + input_dim=len(model_cont_cols), + hidden_dim=temporal_hidden_dim, + num_layers=temporal_num_layers, + dropout=temporal_dropout, + cond_dim=temporal_cond_dim, + ).to(device) + temporal_model.load_state_dict(temporal_state) + temporal_model.eval() + + betas = cosine_beta_schedule(args.timesteps).to(device) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + + x_cont = torch.randn(args.batch_size, args.seq_len, len(model_cont_cols), device=device) + x_disc = torch.full((args.batch_size, args.seq_len, len(disc_cols)), 0, device=device, dtype=torch.long) + mask_tokens = torch.tensor(vocab_sizes, device=device) + for i in range(len(disc_cols)): + x_disc[:, :, i] = mask_tokens[i] + + cond = None + if use_condition: + if cond_vocab_size <= 0: + raise SystemExit("use_condition enabled but no files matched data_glob") + if args.condition_id < 0: + cond_id = torch.randint(0, cond_vocab_size, (args.batch_size,), device=device) + else: + cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long) + cond = cond_id + + cond_cont = None + if route_type1_cols: + ref_glob = cfg.get("data_glob") or args.data_glob + if ref_glob: + ref_glob = str(resolve_path(Path(args.config).parent, ref_glob)) if args.config else ref_glob + base = Path(ref_glob).parent + pat = Path(ref_glob).name + 
refs = sorted(base.glob(pat)) + if refs: + ref_path = refs[0] + ref_rows = [] + with gzip.open(ref_path, "rt", newline="") as fh: + reader = csv.DictReader(fh) + for row in reader: + ref_rows.append(row) + if len(ref_rows) >= args.seq_len: + seq = ref_rows[: args.seq_len] + cond_cont = torch.zeros(args.batch_size, args.seq_len, len(route_type1_cols), device=device) + for t, row in enumerate(seq): + for i, c in enumerate(route_type1_cols): + cond_cont[:, t, i] = float(row[c]) + cond_cont = normalize_cont( + cond_cont, + route_type1_cols, + mean, + std, + transforms=transforms, + quantile_probs=quantile_probs, + quantile_values=quantile_values, + use_quantile=use_quantile, + ) + + trend = None + if temporal_model is not None: + trend = temporal_model.generate(args.batch_size, args.seq_len, device, cond_cont=cond_cont) + if temporal_focus_type4 and type4_cols: + type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols] + if type4_model_idx: + trend_mask = torch.zeros(1, 1, len(model_cont_cols), device=device, dtype=trend.dtype) + trend_mask[:, :, type4_model_idx] = 1.0 + trend = trend * trend_mask + elif temporal_exclude_type4 and type4_cols: + type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols] + if type4_model_idx: + trend_mask = torch.ones(1, 1, len(model_cont_cols), device=device, dtype=trend.dtype) + trend_mask[:, :, type4_model_idx] = 0.0 + trend = trend * trend_mask + + for t in reversed(range(args.timesteps)): + t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long) + eps_pred, logits = model(x_cont, x_disc, t_batch, cond, cond_cont=cond_cont) + + a_t = alphas[t] + a_bar_t = alphas_cumprod[t] + + if cont_target == "x0": + x0_pred = eps_pred + if cont_clamp_x0 > 0: + x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0) + eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t) + coef1 = 1.0 / torch.sqrt(a_t) + coef2 = (1 - a_t) / 
torch.sqrt(1 - a_bar_t) + mean_x = coef1 * (x_cont - coef2 * eps_pred) + if t > 0: + noise = torch.randn_like(x_cont) + x_cont = mean_x + torch.sqrt(betas[t]) * noise + else: + x_cont = mean_x + if args.clip_k > 0: + x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k) + + for i, logit in enumerate(logits): + if t == 0: + probs = F.softmax(logit, dim=-1) + x_disc[:, :, i] = torch.argmax(probs, dim=-1) + else: + mask = x_disc[:, :, i] == mask_tokens[i] + if mask.any(): + probs = F.softmax(logit, dim=-1) + sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view( + args.batch_size, args.seq_len + ) + x_disc[:, :, i][mask] = sampled[mask] + + if trend is not None: + x_cont = x_cont + trend + x_cont = x_cont.cpu() + x_disc = x_disc.cpu() + + if args.clip_k > 0: + x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k) + + if use_quantile: + q_vals = {c: quantile_values[c] for c in model_cont_cols} + x_cont = inverse_quantile_transform(x_cont, model_cont_cols, quantile_probs, q_vals) + else: + mean_vec = torch.tensor([mean[c] for c in model_cont_cols], dtype=x_cont.dtype) + std_vec = torch.tensor([std[c] for c in model_cont_cols], dtype=x_cont.dtype) + x_cont = x_cont * std_vec + mean_vec + for i, c in enumerate(model_cont_cols): + if transforms.get(c) == "log1p": + x_cont[:, :, i] = torch.expm1(x_cont[:, :, i]) + if cont_post_calibrate and quantile_raw_values and quantile_probs: + q_raw = {c: quantile_raw_values[c] for c in model_cont_cols} + x_cont = quantile_calibrate_to_real(x_cont, model_cont_cols, quantile_probs, q_raw) + if vmin and vmax: + for i, c in enumerate(model_cont_cols): + lo = vmin.get(c, None) + hi = vmax.get(c, None) + if lo is None or hi is None: + continue + lo = float(lo) + hi = float(hi) + if cont_bound_mode == "none": + continue + if cont_bound_mode == "sigmoid": + x_cont[:, :, i] = lo + (hi - lo) * torch.sigmoid(x_cont[:, :, i]) + elif cont_bound_mode == "soft_tanh": + mid = 0.5 * (lo + hi) + half = 0.5 * (hi - lo) + denom = 
cont_bound_strength if cont_bound_strength > 0 else 1.0 + x_cont[:, :, i] = mid + half * torch.tanh(x_cont[:, :, i] / denom) + else: + x_cont[:, :, i] = torch.clamp(x_cont[:, :, i], lo, hi) + + if cont_post_scale: + for i, c in enumerate(model_cont_cols): + if c in cont_post_scale: + try: + scale = float(cont_post_scale[c]) + except Exception: + scale = 1.0 + x_cont[:, :, i] = x_cont[:, :, i] * scale + + full_cont = torch.zeros(args.batch_size, args.seq_len, len(cont_cols), dtype=x_cont.dtype) + for i, c in enumerate(model_cont_cols): + full_idx = cont_cols.index(c) + full_cont[:, :, full_idx] = x_cont[:, :, i] + if cond_cont is not None and route_type1_cols: + cond_denorm = denormalize_cont_tensor( + cond_cont.cpu(), + route_type1_cols, + mean, + std, + transforms=transforms, + quantile_probs=quantile_probs, + quantile_values=quantile_values, + use_quantile=use_quantile, + ) + for i, c in enumerate(route_type1_cols): + full_idx = cont_cols.index(c) + full_cont[:, :, full_idx] = cond_denorm[:, :, i] + for c in route_type5_cols: + if c.endswith("Z"): + base = c[:-1] + if base in cont_cols: + bidx = cont_cols.index(base) + cidx = cont_cols.index(c) + full_cont[:, :, cidx] = full_cont[:, :, bidx] + + header = read_header(data_path) + out_cols = [c for c in header if c != time_col or args.include_time] + if args.include_condition and use_condition: + out_cols = ["__cond_file_id"] + out_cols + + os.makedirs(os.path.dirname(args.out), exist_ok=True) + with open(args.out, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=out_cols) + writer.writeheader() + + row_index = 0 + for b in range(args.batch_size): + for t in range(args.seq_len): + row = {} + if args.include_condition and use_condition: + row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1" + if args.include_time and time_col in header: + row[time_col] = str(row_index) + for i, c in enumerate(cont_cols): + val = float(full_cont[b, t, i]) + if int_like.get(c, 
False): + row[c] = str(int(round(val))) + else: + dec = int(max_decimals.get(c, 6)) + fmt = ("%%.%df" % dec) if dec > 0 else "%.0f" + row[c] = fmt % val + for i, c in enumerate(disc_cols): + tok_idx = int(x_disc[b, t, i]) + tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "" + if tok == "" and c in top_token: + tok = top_token[c] + row[c] = tok + writer.writerow(row) + row_index += 1 + + +if __name__ == "__main__": + main() diff --git a/example/run_submission_full.sh b/example/run_submission_full.sh new file mode 100644 index 0000000..b1fc202 --- /dev/null +++ b/example/run_submission_full.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +RUN_DIR="${RUN_DIR:-$SCRIPT_DIR/results/submission_full}" +LOG_DIR="$RUN_DIR/logs" +mkdir -p "$LOG_DIR" + +STAMP="$(date '+%Y%m%d-%H%M%S')" +LOG_FILE="$LOG_DIR/pipeline-$STAMP.log" + +echo "[run_submission_full] run_dir=$RUN_DIR" +echo "[run_submission_full] log_file=$LOG_FILE" + +python "$SCRIPT_DIR/run_submission_resume.py" \ + --config "$SCRIPT_DIR/config_submission_full.json" \ + --device "${DEVICE:-cuda}" \ + --run-dir "$RUN_DIR" \ + "$@" 2>&1 | tee -a "$LOG_FILE" diff --git a/example/run_submission_resume.py b/example/run_submission_resume.py new file mode 100644 index 0000000..3d88207 --- /dev/null +++ b/example/run_submission_resume.py @@ -0,0 +1,399 @@ +#!/usr/bin/env python3 +"""One-command full pipeline runner with safe resume and stage skipping.""" + +from __future__ import annotations + +import argparse +import json +import subprocess +import sys +from pathlib import Path +from typing import Dict, List + +from platform_utils import is_windows, safe_path + + +def run(cmd: List[str]) -> None: + print("running:", " ".join(cmd)) + cmd = [safe_path(arg) for arg in cmd] + if is_windows(): + subprocess.run(cmd, check=True, shell=False) + else: + subprocess.run(cmd, check=True) + + +def parse_args(): + base_dir = 
Path(__file__).resolve().parent + parser = argparse.ArgumentParser(description="Run prepare -> train -> export -> eval with resume-aware staging.") + parser.add_argument("--config", default=str(base_dir / "config_submission_full.json")) + parser.add_argument("--device", default="auto") + parser.add_argument("--run-dir", default=str(base_dir / "results" / "submission_full")) + parser.add_argument("--reference", default="") + parser.add_argument("--no-resume", action="store_true", help="Do not auto-skip completed stages or resume from ckpt.") + parser.add_argument("--skip-prepare", action="store_true") + parser.add_argument("--skip-train", action="store_true") + parser.add_argument("--skip-export", action="store_true") + parser.add_argument("--skip-eval", action="store_true") + parser.add_argument("--skip-comprehensive-eval", action="store_true") + parser.add_argument("--skip-postprocess", action="store_true") + parser.add_argument("--skip-post-eval", action="store_true") + parser.add_argument("--skip-diagnostics", action="store_true") + return parser.parse_args() + + +def load_state(path: Path) -> Dict[str, str]: + if not path.exists(): + return {} + try: + return json.loads(path.read_text(encoding="utf-8")) + except Exception: + return {} + + +def save_state(path: Path, state: Dict[str, str]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(state, indent=2, sort_keys=True), encoding="utf-8") + + +def stage_complete(state: Dict[str, str], stage: str, outputs: List[Path], resume: bool) -> bool: + if not resume: + return False + if outputs and all(p.exists() for p in outputs): + return True + return state.get(stage) == "done" + + +def main(): + args = parse_args() + base_dir = Path(__file__).resolve().parent + config_path = Path(args.config) + if not config_path.is_absolute(): + config_path = (base_dir / config_path).resolve() + run_dir = Path(args.run_dir) + if not run_dir.is_absolute(): + run_dir = (base_dir / run_dir).resolve() 
+ run_dir.mkdir(parents=True, exist_ok=True) + + cfg = json.loads(config_path.read_text(encoding="utf-8")) + cfg_base = config_path.parent + + def abs_cfg_like(value: str) -> str: + p = Path(value) + if p.is_absolute(): + return str(p) + if any(ch in value for ch in ["*", "?", "["]): + return str(cfg_base / p) + return str((cfg_base / p).resolve()) + + ref = args.reference or cfg.get("data_glob") or cfg.get("data_path") or "" + if ref: + ref = abs_cfg_like(str(ref)) + + timesteps = int(cfg.get("timesteps", 200)) + seq_len = int(cfg.get("sample_seq_len", cfg.get("seq_len", 64))) + batch_size = int(cfg.get("sample_batch_size", cfg.get("batch_size", 2))) + clip_k = float(cfg.get("clip_k", 5.0)) + split_path = abs_cfg_like(str(cfg.get("split_path", "./feature_split.json"))) + stats_path = abs_cfg_like(str(cfg.get("stats_path", "./results/cont_stats.json"))) + vocab_path = abs_cfg_like(str(cfg.get("vocab_path", "./results/disc_vocab.json"))) + data_path = abs_cfg_like(str(cfg.get("data_path", ""))) if cfg.get("data_path") else "" + data_glob = abs_cfg_like(str(cfg.get("data_glob", ""))) if cfg.get("data_glob") else "" + + state_path = run_dir / "pipeline_state.json" + state = load_state(state_path) + resume = not args.no_resume + cfg_for_steps = run_dir / "config_used.json" + + stage_defs = [] + if not args.skip_prepare: + stage_defs.append( + ( + "prepare", + [Path(stats_path), Path(vocab_path)], + [sys.executable, str(base_dir / "prepare_data.py"), "--config", str(config_path)], + ) + ) + if not args.skip_train: + train_cmd = [ + sys.executable, + str(base_dir / "train_resume.py"), + "--config", + str(config_path), + "--device", + args.device, + "--out-dir", + str(run_dir), + "--seed", + str(int(cfg.get("seed", 1337))), + ] + if resume: + train_cmd.append("--resume") + stage_defs.append(("train", [run_dir / "model.pt"], train_cmd)) + if not args.skip_export: + stage_defs.append( + ( + "export", + [run_dir / "generated.csv"], + [ + sys.executable, + str(base_dir / 
"export_samples_resume.py"), + "--include-time", + "--device", + args.device, + "--config", + str(cfg_for_steps if cfg_for_steps.exists() else config_path), + "--data-path", + str(data_path), + "--data-glob", + str(data_glob), + "--split-path", + str(split_path), + "--stats-path", + str(stats_path), + "--vocab-path", + str(vocab_path), + "--model-path", + str(run_dir / "model.pt"), + "--out", + str(run_dir / "generated.csv"), + "--timesteps", + str(timesteps), + "--seq-len", + str(seq_len), + "--batch-size", + str(batch_size), + "--clip-k", + str(clip_k), + "--use-ema", + ], + ) + ) + if not args.skip_eval: + eval_cmd = [ + sys.executable, + str(base_dir / "evaluate_generated.py"), + "--generated", + str(run_dir / "generated.csv"), + "--split", + str(split_path), + "--stats", + str(stats_path), + "--vocab", + str(vocab_path), + "--out", + str(run_dir / "eval.json"), + ] + if ref: + eval_cmd += ["--reference", str(ref)] + stage_defs.append(("eval", [run_dir / "eval.json"], eval_cmd)) + if not args.skip_comprehensive_eval: + stage_defs.append( + ( + "comprehensive_eval", + [run_dir / "comprehensive_eval.json"], + [ + sys.executable, + str(base_dir / "evaluate_comprehensive.py"), + "--generated", + str(run_dir / "generated.csv"), + "--reference", + str(config_path), + "--config", + str(cfg_for_steps if cfg_for_steps.exists() else config_path), + "--split", + str(split_path), + "--stats", + str(stats_path), + "--vocab", + str(vocab_path), + "--out", + str(run_dir / "comprehensive_eval.json"), + "--device", + args.device, + ], + ) + ) + if not args.skip_postprocess: + post_cmd = [ + sys.executable, + str(base_dir / "postprocess_types.py"), + "--generated", + str(run_dir / "generated.csv"), + "--config", + str(cfg_for_steps if cfg_for_steps.exists() else config_path), + "--out", + str(run_dir / "generated_post.csv"), + "--seed", + str(int(cfg.get("seed", 1337))), + ] + if ref: + post_cmd += ["--reference", str(ref)] + stage_defs.append(("postprocess", [run_dir / 
"generated_post.csv"], post_cmd)) + if not args.skip_post_eval: + post_eval_cmd = [ + sys.executable, + str(base_dir / "evaluate_generated.py"), + "--generated", + str(run_dir / "generated_post.csv"), + "--split", + str(split_path), + "--stats", + str(stats_path), + "--vocab", + str(vocab_path), + "--out", + str(run_dir / "eval_post.json"), + ] + if ref: + post_eval_cmd += ["--reference", str(ref)] + stage_defs.append(("post_eval", [run_dir / "eval_post.json"], post_eval_cmd)) + if not args.skip_comprehensive_eval: + stage_defs.append( + ( + "comprehensive_post_eval", + [run_dir / "comprehensive_eval_post.json"], + [ + sys.executable, + str(base_dir / "evaluate_comprehensive.py"), + "--generated", + str(run_dir / "generated_post.csv"), + "--reference", + str(config_path), + "--config", + str(cfg_for_steps if cfg_for_steps.exists() else config_path), + "--split", + str(split_path), + "--stats", + str(stats_path), + "--vocab", + str(vocab_path), + "--out", + str(run_dir / "comprehensive_eval_post.json"), + "--device", + args.device, + ], + ) + ) + if not args.skip_diagnostics: + stage_defs.extend( + [ + ( + "filtered_metrics", + [run_dir / "filtered_metrics.json"], + [ + sys.executable, + str(base_dir / "filtered_metrics.py"), + "--eval", + str(run_dir / "eval.json"), + "--out", + str(run_dir / "filtered_metrics.json"), + ], + ), + ( + "ranked_ks", + [run_dir / "ranked_ks.csv"], + [ + sys.executable, + str(base_dir / "ranked_ks.py"), + "--eval", + str(run_dir / "eval.json"), + "--out", + str(run_dir / "ranked_ks.csv"), + ], + ), + ( + "program_stats", + [run_dir / "program_stats.json"], + [ + sys.executable, + str(base_dir / "program_stats.py"), + "--generated", + str(run_dir / "generated.csv"), + "--reference", + str(config_path), + "--config", + str(cfg_for_steps if cfg_for_steps.exists() else config_path), + ], + ), + ( + "controller_stats", + [run_dir / "controller_stats.json"], + [ + sys.executable, + str(base_dir / "controller_stats.py"), + "--generated", + 
str(run_dir / "generated.csv"), + "--reference", + str(config_path), + "--config", + str(cfg_for_steps if cfg_for_steps.exists() else config_path), + ], + ), + ( + "actuator_stats", + [run_dir / "actuator_stats.json"], + [ + sys.executable, + str(base_dir / "actuator_stats.py"), + "--generated", + str(run_dir / "generated.csv"), + "--reference", + str(config_path), + "--config", + str(cfg_for_steps if cfg_for_steps.exists() else config_path), + ], + ), + ( + "pv_stats", + [run_dir / "pv_stats.json"], + [ + sys.executable, + str(base_dir / "pv_stats.py"), + "--generated", + str(run_dir / "generated.csv"), + "--reference", + str(config_path), + "--config", + str(cfg_for_steps if cfg_for_steps.exists() else config_path), + ], + ), + ( + "aux_stats", + [run_dir / "aux_stats.json"], + [ + sys.executable, + str(base_dir / "aux_stats.py"), + "--generated", + str(run_dir / "generated.csv"), + "--reference", + str(config_path), + "--config", + str(cfg_for_steps if cfg_for_steps.exists() else config_path), + ], + ), + ] + ) + + command_log = run_dir / "run_commands.txt" + if not command_log.exists(): + command_log.write_text("", encoding="utf-8") + + for stage, outputs, cmd in stage_defs: + if stage_complete(state, stage, outputs, resume): + print(f"skip_stage {stage}: outputs already present") + state[stage] = "done" + save_state(state_path, state) + continue + state[stage] = "running" + save_state(state_path, state) + with command_log.open("a", encoding="utf-8") as fh: + fh.write(stage + ": " + " ".join(cmd) + "\n") + run(cmd) + state[stage] = "done" + save_state(state_path, state) + + print(f"pipeline_complete run_dir={run_dir}") + + +if __name__ == "__main__": + main() diff --git a/example/submission_type_utils.py b/example/submission_type_utils.py new file mode 100644 index 0000000..aa2d94f --- /dev/null +++ b/example/submission_type_utils.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +"""Helpers for keeping type taxonomy and routing policy separate.""" + +from __future__ 
import annotations + +from typing import Dict, List, Optional + +import torch + +from data_utils import inverse_quantile_transform + + +def resolve_taxonomy_features(config: Dict, cont_cols: List[str], base_key: str) -> List[str]: + feats = config.get(base_key, []) or [] + return [c for c in feats if c in cont_cols] + + +def resolve_routing_features(config: Dict, cont_cols: List[str], base_key: str) -> List[str]: + feats = config.get(f"routing_{base_key}", config.get(base_key, [])) or [] + return [c for c in feats if c in cont_cols] + + +def denormalize_cont_tensor( + x: torch.Tensor, + cont_cols: List[str], + mean: Dict[str, float], + std: Dict[str, float], + transforms: Optional[Dict[str, str]] = None, + quantile_probs: Optional[List[float]] = None, + quantile_values: Optional[Dict[str, List[float]]] = None, + use_quantile: bool = False, +) -> torch.Tensor: + if x is None: + raise ValueError("x must not be None") + if not cont_cols: + return x.clone() + + out = x.clone() + if use_quantile: + if not quantile_probs or not quantile_values: + raise ValueError("use_quantile=True but quantile stats are missing") + q_vals = {c: quantile_values[c] for c in cont_cols} + out = inverse_quantile_transform(out, cont_cols, quantile_probs, q_vals) + else: + mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=out.dtype, device=out.device) + std_vec = torch.tensor([std[c] for c in cont_cols], dtype=out.dtype, device=out.device) + out = out * std_vec + mean_vec + + if transforms: + for i, c in enumerate(cont_cols): + if transforms.get(c) == "log1p": + out[:, :, i] = torch.expm1(out[:, :, i]) + return out diff --git a/example/train_resume.py b/example/train_resume.py new file mode 100644 index 0000000..be4f8d5 --- /dev/null +++ b/example/train_resume.py @@ -0,0 +1,534 @@ +#!/usr/bin/env python3 +"""Train hybrid diffusion with checkpoint resume and selective type-aware routing.""" + +from __future__ import annotations + +import argparse +import json +import os +from pathlib 
# Train hybrid diffusion on HAI with checkpoint resume and selective
# type-aware routing (stage-1 temporal trend model + stage-2 diffusion).

import argparse
import json
import os
from pathlib import Path
from typing import Dict

import torch
import torch.nn.functional as F

from data_utils import load_split, windowed_batches
from hybrid_diffusion import (
    HybridDiffusionModel,
    TemporalGRUGenerator,
    TemporalTransformerGenerator,
    cosine_beta_schedule,
    q_sample_continuous,
    q_sample_discrete,
)
from platform_utils import resolve_device, resolve_path, safe_path
from submission_type_utils import resolve_routing_features, resolve_taxonomy_features
from train import DEFAULTS, EMA, load_json, resolve_config_paths, set_seed

BASE_DIR = Path(__file__).resolve().parent


def load_torch_state(path: str, device: str):
    """Load a torch checkpoint, preferring safe ``weights_only`` loading.

    Older torch versions do not accept ``weights_only``; fall back to the
    legacy call signature when a TypeError is raised.
    """
    try:
        return torch.load(path, map_location=device, weights_only=True)
    except TypeError:
        return torch.load(path, map_location=device)


def atomic_torch_save(obj, path: Path) -> None:
    """Save *obj* to *path* atomically (write tmp sibling, then rename).

    ``os.replace`` is atomic on the same filesystem, so a crash mid-save
    never leaves a truncated checkpoint at *path*.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Append ".tmp" to the full file name; with_name avoids the suffix-surgery
    # edge cases of with_suffix while producing the same sibling path.
    tmp = path.with_name(path.name + ".tmp")
    torch.save(obj, str(tmp))
    os.replace(str(tmp), str(path))


def parse_args():
    """Parse CLI options; every option overrides the JSON config when set."""
    parser = argparse.ArgumentParser(description="Train hybrid diffusion on HAI with resume support.")
    parser.add_argument("--config", default=None, help="Path to JSON config.")
    parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    parser.add_argument("--out-dir", default=None, help="Override output directory")
    parser.add_argument("--seed", type=int, default=None, help="Override random seed")
    parser.add_argument("--temporal-only", action="store_true", help="Only train temporal stage-1 and exit.")
    parser.add_argument("--resume", action="store_true", help="Resume from checkpoint in out-dir if present.")
    parser.add_argument("--resume-ckpt", default=None, help="Optional explicit model checkpoint path.")
    return parser.parse_args()


def build_temporal_model(config: Dict, model_cont_cols, temporal_cond_dim: int, device: str):
    """Build the stage-1 temporal generator selected by ``temporal_backbone``.

    ``"transformer"`` selects TemporalTransformerGenerator; anything else
    falls back to the GRU variant.
    """
    temporal_backbone = str(config.get("temporal_backbone", "gru"))
    if temporal_backbone == "transformer":
        return TemporalTransformerGenerator(
            input_dim=len(model_cont_cols),
            hidden_dim=int(config.get("temporal_hidden_dim", 256)),
            num_layers=int(config.get("temporal_transformer_num_layers", 2)),
            nhead=int(config.get("temporal_transformer_nhead", 4)),
            ff_dim=int(config.get("temporal_transformer_ff_dim", 512)),
            dropout=float(config.get("temporal_transformer_dropout", 0.1)),
            pos_dim=int(config.get("temporal_pos_dim", 64)),
            use_pos_embed=bool(config.get("temporal_use_pos_embed", True)),
            cond_dim=temporal_cond_dim,
        ).to(device)
    return TemporalGRUGenerator(
        input_dim=len(model_cont_cols),
        hidden_dim=int(config.get("temporal_hidden_dim", 256)),
        num_layers=int(config.get("temporal_num_layers", 1)),
        dropout=float(config.get("temporal_dropout", 0.0)),
        cond_dim=temporal_cond_dim,
    ).to(device)


def init_or_append_log(log_path: Path, resume: bool) -> None:
    """Create the CSV train log with a header, unless resuming onto an existing one."""
    if resume and log_path.exists():
        return
    with open(log_path, "w", encoding="utf-8") as f:
        f.write("epoch,step,loss,loss_cont,loss_disc\n")


def main():
    args = parse_args()
    if args.config:
        print("using_config", str(Path(args.config).resolve()))
    # Start from defaults, then layer the JSON config; paths are resolved
    # relative to the config file (or this script when no config is given).
    config = dict(DEFAULTS)
    if args.config:
        cfg_path = Path(args.config).resolve()
        config.update(load_json(str(cfg_path)))
        config = resolve_config_paths(config, cfg_path.parent)
    else:
        config = resolve_config_paths(config, BASE_DIR)

    # CLI overrides win over config values.
    if args.device != "auto":
        config["device"] = args.device
    if args.out_dir:
        out_dir = Path(args.out_dir)
        if not out_dir.is_absolute():
            base = Path(args.config).resolve().parent if args.config else BASE_DIR
            out_dir = resolve_path(base, out_dir)
        config["out_dir"] = str(out_dir)
    if args.seed is not None:
        config["seed"] = int(args.seed)
    if bool(args.temporal_only):
        # epochs=0 skips the main diffusion loop entirely.
        config["use_temporal_stage1"] = True
        config["epochs"] = 0

    set_seed(int(config["seed"]))

    # Column bookkeeping: time column and attack labels are never modeled.
    split = load_split(config["split_path"])
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    # Routing: type1 columns become conditioning inputs, type5 columns are
    # excluded; the diffusion model only generates the remaining columns.
    type1_cols = resolve_routing_features(config, cont_cols, "type1_features")
    type5_cols = resolve_routing_features(config, cont_cols, "type5_features")
    type4_cols = resolve_taxonomy_features(config, cont_cols, "type4_features")
    model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
    if not model_cont_cols:
        raise SystemExit("model_cont_cols is empty; check routing_type1_features/routing_type5_features")

    # Column-index lookups are loop-invariant: compute once here instead of
    # re-running list.index() scans for every training batch.
    model_idx = [cont_cols.index(c) for c in model_cont_cols]
    type1_idx = [cont_cols.index(c) for c in type1_cols] if type1_cols else []

    stats = load_json(config["stats_path"])
    mean = stats["mean"]
    std = stats["std"]
    transforms = stats.get("transform", {})
    raw_std = stats.get("raw_std", std)
    quantile_probs = stats.get("quantile_probs")
    quantile_values = stats.get("quantile_values")
    use_quantile = bool(config.get("use_quantile_transform", False))

    vocab = load_json(config["vocab_path"])["vocab"]
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    # Expand the optional glob; fall back to the single data_path.
    data_paths = None
    if "data_glob" in config and config["data_glob"]:
        data_paths = sorted(Path(config["data_glob"]).parent.glob(Path(config["data_glob"]).name))
        if data_paths:
            data_paths = [safe_path(p) for p in data_paths]
    if not data_paths:
        data_paths = [safe_path(config["data_path"])]

    use_condition = bool(config.get("use_condition")) and config.get("condition_type") == "file_id"
    cond_vocab_size = len(data_paths) if use_condition else 0

    device = resolve_device(str(config["device"]))
    print("device", device)
    model = HybridDiffusionModel(
        cont_dim=len(model_cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=int(config.get("model_time_dim", 64)),
        hidden_dim=int(config.get("model_hidden_dim", 256)),
        num_layers=int(config.get("model_num_layers", 1)),
        dropout=float(config.get("model_dropout", 0.0)),
        ff_mult=int(config.get("model_ff_mult", 2)),
        pos_dim=int(config.get("model_pos_dim", 64)),
        use_pos_embed=bool(config.get("model_use_pos_embed", True)),
        backbone_type=str(config.get("backbone_type", "gru")),
        transformer_num_layers=int(config.get("transformer_num_layers", 4)),
        transformer_nhead=int(config.get("transformer_nhead", 8)),
        transformer_ff_dim=int(config.get("transformer_ff_dim", 2048)),
        transformer_dropout=float(config.get("transformer_dropout", 0.1)),
        cond_cont_dim=len(type1_cols),
        cond_vocab_size=cond_vocab_size,
        cond_dim=int(config.get("cond_dim", 32)),
        use_tanh_eps=bool(config.get("use_tanh_eps", False)),
        eps_scale=float(config.get("eps_scale", 1.0)),
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))

    # Optional stage-1 temporal trend model and its type4 focus/exclude mask.
    temporal_model = None
    opt_temporal = None
    temporal_use_type1_cond = bool(config.get("temporal_use_type1_cond", False))
    temporal_cond_dim = len(type1_cols) if (temporal_use_type1_cond and type1_cols) else 0
    temporal_focus_type4 = bool(config.get("temporal_focus_type4", False))
    temporal_exclude_type4 = bool(config.get("temporal_exclude_type4", False))
    type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols]
    trend_mask = None
    if temporal_focus_type4 and type4_model_idx:
        trend_mask = torch.zeros(1, 1, len(model_cont_cols), device=device)
        trend_mask[:, :, type4_model_idx] = 1.0
    elif temporal_exclude_type4 and type4_model_idx:
        trend_mask = torch.ones(1, 1, len(model_cont_cols), device=device)
        trend_mask[:, :, type4_model_idx] = 0.0
    if bool(config.get("use_temporal_stage1", False)):
        temporal_model = build_temporal_model(config, model_cont_cols, temporal_cond_dim, device)
        opt_temporal = torch.optim.Adam(
            temporal_model.parameters(),
            lr=float(config.get("temporal_lr", config["lr"])),
        )
    ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None

    # Diffusion noise schedule.
    betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # Mask-token ids for the discrete corruption process (loop-invariant).
    mask_tokens = torch.tensor(vocab_sizes, device=device)

    os.makedirs(config["out_dir"], exist_ok=True)
    out_dir = Path(safe_path(config["out_dir"]))
    log_path = out_dir / "train_log.csv"
    init_or_append_log(log_path, args.resume)

    with open(out_dir / "config_used.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    main_ckpt_path = Path(args.resume_ckpt).resolve() if args.resume_ckpt else (out_dir / "model_ckpt.pt")
    temporal_ckpt_path = out_dir / "temporal_ckpt.pt"
    model_path = out_dir / "model.pt"
    ema_path = out_dir / "model_ema.pt"
    temporal_path = out_dir / "temporal.pt"

    temporal_start_epoch = 0
    temporal_start_step = 0
    temporal_total_step = 0
    main_start_epoch = 0
    main_start_step = 0
    total_step = 0
    temporal_done = temporal_model is None

    # Resume priority: full main checkpoint > in-progress temporal checkpoint
    # > completed temporal weights.
    if args.resume:
        if main_ckpt_path.exists():
            ckpt = load_torch_state(str(main_ckpt_path), device)
            model.load_state_dict(ckpt["model"])
            opt.load_state_dict(ckpt["optim"])
            total_step = int(ckpt.get("step", 0))
            main_start_epoch = int(ckpt.get("epoch", 0))
            main_start_step = int(ckpt.get("step_in_epoch", 0))
            temporal_done = bool(ckpt.get("temporal_done", temporal_done))
            temporal_total_step = int(ckpt.get("temporal_step", 0))
            if ema is not None and ckpt.get("ema") is not None:
                ema.shadow = ckpt["ema"]
            if temporal_model is not None and ckpt.get("temporal") is not None:
                temporal_model.load_state_dict(ckpt["temporal"])
            if opt_temporal is not None and ckpt.get("temporal_optim") is not None:
                opt_temporal.load_state_dict(ckpt["temporal_optim"])
            print(f"resumed_main_ckpt epoch={main_start_epoch} step={main_start_step} total_step={total_step}")
        elif temporal_ckpt_path.exists() and temporal_model is not None and opt_temporal is not None:
            tckpt = load_torch_state(str(temporal_ckpt_path), device)
            temporal_model.load_state_dict(tckpt["temporal"])
            opt_temporal.load_state_dict(tckpt["temporal_optim"])
            temporal_start_epoch = int(tckpt.get("epoch", 0))
            temporal_start_step = int(tckpt.get("step_in_epoch", 0))
            temporal_total_step = int(tckpt.get("temporal_step", 0))
            print(
                f"resumed_temporal_ckpt epoch={temporal_start_epoch} "
                f"step={temporal_start_step} temporal_step={temporal_total_step}"
            )
        elif temporal_path.exists() and temporal_model is not None:
            temporal_model.load_state_dict(load_torch_state(str(temporal_path), device))
            temporal_done = True
            print("reused_completed_temporal_stage", str(temporal_path))

    # ---- Stage 1: temporal trend model (next-step prediction) ----
    if temporal_model is not None and opt_temporal is not None and not temporal_done:
        for epoch in range(temporal_start_epoch, int(config.get("temporal_epochs", 1))):
            # When resuming mid-epoch, fast-forward past already-seen batches.
            skip_until = temporal_start_step if epoch == temporal_start_epoch else 0
            for step, batch in enumerate(
                windowed_batches(
                    data_paths,
                    cont_cols,
                    disc_cols,
                    vocab,
                    mean,
                    std,
                    batch_size=int(config["batch_size"]),
                    seq_len=int(config["seq_len"]),
                    max_batches=int(config["max_batches"]),
                    return_file_id=False,
                    transforms=transforms,
                    quantile_probs=quantile_probs,
                    quantile_values=quantile_values,
                    use_quantile=use_quantile,
                    shuffle_buffer=int(config.get("shuffle_buffer", 0)),
                )
            ):
                if step < skip_until:
                    continue
                x_cont, _ = batch
                x_cont = x_cont.to(device)
                x_cont_model = x_cont[:, :, model_idx]
                cond_cont = None
                if temporal_cond_dim > 0:
                    cond_cont = x_cont[:, :, type1_idx]
                _, pred_next = temporal_model.forward_teacher(x_cont_model, cond_cont=cond_cont)
                target_next = x_cont_model[:, 1:, :]
                if trend_mask is not None:
                    # Masked MSE over the selected feature subset only.
                    mask = trend_mask.to(dtype=pred_next.dtype, device=pred_next.device)
                    mse = (pred_next - target_next) ** 2
                    temporal_loss = (mse * mask).sum() / torch.clamp(mask.sum() * mse.size(0) * mse.size(1), min=1.0)
                else:
                    temporal_loss = F.mse_loss(pred_next, target_next)
                opt_temporal.zero_grad()
                temporal_loss.backward()
                if float(config.get("grad_clip", 0.0)) > 0:
                    torch.nn.utils.clip_grad_norm_(temporal_model.parameters(), float(config["grad_clip"]))
                opt_temporal.step()
                temporal_total_step += 1
                if step % int(config["log_every"]) == 0:
                    print("temporal_epoch", epoch, "step", step, "loss", float(temporal_loss))
                if temporal_total_step % int(config["ckpt_every"]) == 0:
                    atomic_torch_save(
                        {
                            "temporal": temporal_model.state_dict(),
                            "temporal_optim": opt_temporal.state_dict(),
                            "epoch": epoch,
                            "step_in_epoch": step + 1,
                            "temporal_step": temporal_total_step,
                            "config": config,
                        },
                        temporal_ckpt_path,
                    )
            temporal_start_step = 0
            # End-of-epoch checkpoint marks the NEXT epoch as the resume point.
            atomic_torch_save(
                {
                    "temporal": temporal_model.state_dict(),
                    "temporal_optim": opt_temporal.state_dict(),
                    "epoch": epoch + 1,
                    "step_in_epoch": 0,
                    "temporal_step": temporal_total_step,
                    "config": config,
                },
                temporal_ckpt_path,
            )
        atomic_torch_save(temporal_model.state_dict(), temporal_path)
        temporal_done = True

    if bool(args.temporal_only):
        return

    # ---- Stage 2: hybrid diffusion over the temporal residual ----
    for epoch in range(main_start_epoch, int(config["epochs"])):
        skip_until = main_start_step if epoch == main_start_epoch else 0
        for step, batch in enumerate(
            windowed_batches(
                data_paths,
                cont_cols,
                disc_cols,
                vocab,
                mean,
                std,
                batch_size=int(config["batch_size"]),
                seq_len=int(config["seq_len"]),
                max_batches=int(config["max_batches"]),
                return_file_id=use_condition,
                transforms=transforms,
                quantile_probs=quantile_probs,
                quantile_values=quantile_values,
                use_quantile=use_quantile,
                shuffle_buffer=int(config.get("shuffle_buffer", 0)),
            )
        ):
            if step < skip_until:
                continue
            if use_condition:
                x_cont, x_disc, cond = batch
                cond = cond.to(device)
            else:
                x_cont, x_disc = batch
                cond = None
            x_cont = x_cont.to(device)
            x_disc = x_disc.to(device)

            x_cont_model = x_cont[:, :, model_idx]
            cond_cont = x_cont[:, :, type1_idx] if type1_idx else None

            # Subtract the frozen stage-1 trend so diffusion models residuals.
            trend = None
            if temporal_model is not None:
                temporal_model.eval()
                with torch.no_grad():
                    trend, _ = temporal_model.forward_teacher(x_cont_model, cond_cont=cond_cont)
                if trend_mask is not None and trend is not None:
                    trend = trend * trend_mask.to(dtype=trend.dtype, device=trend.device)
            x_cont_resid = x_cont_model if trend is None else x_cont_model - trend

            bsz = x_cont.size(0)
            t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)

            x_cont_t, noise = q_sample_continuous(x_cont_resid, t, alphas_cumprod)

            x_disc_t, mask = q_sample_discrete(
                x_disc,
                t,
                mask_tokens,
                int(config["timesteps"]),
                mask_scale=float(config.get("disc_mask_scale", 1.0)),
            )

            eps_pred, logits = model(x_cont_t, x_disc_t, t, cond, cond_cont=cond_cont)

            # Continuous loss: eps-prediction or (clamped) x0-prediction.
            cont_target = str(config.get("cont_target", "eps"))
            if cont_target == "x0":
                x0_target = x_cont_resid
                if float(config.get("cont_clamp_x0", 0.0)) > 0:
                    x0_target = torch.clamp(
                        x0_target,
                        -float(config["cont_clamp_x0"]),
                        float(config["cont_clamp_x0"]),
                    )
                loss_base = (eps_pred - x0_target) ** 2
            else:
                loss_base = (eps_pred - noise) ** 2

            if config.get("cont_loss_weighting") == "inv_std":
                weights = torch.tensor(
                    [1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in model_cont_cols],
                    device=device,
                    dtype=eps_pred.dtype,
                ).view(1, 1, -1)
                loss_cont = (loss_base * weights).mean()
            else:
                loss_cont = loss_base.mean()

            # Min-SNR-style reweighting (applied via the batch-mean weight).
            if bool(config.get("snr_weighted_loss", False)):
                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
                snr = a_bar_t / torch.clamp(1.0 - a_bar_t, min=1e-8)
                gamma = float(config.get("snr_gamma", 1.0))
                snr_weight = snr / (snr + gamma)
                loss_cont = (loss_cont * snr_weight.mean()).mean()

            # Discrete loss: cross-entropy only over masked positions.
            loss_disc = 0.0
            loss_disc_count = 0
            for i, logit in enumerate(logits):
                if mask[:, :, i].any():
                    loss_disc = loss_disc + F.cross_entropy(
                        logit[mask[:, :, i]],
                        x_disc[:, :, i][mask[:, :, i]],
                    )
                    loss_disc_count += 1
            if loss_disc_count > 0:
                loss_disc = loss_disc / loss_disc_count

            lam = float(config["lambda"])
            loss = lam * loss_cont + (1 - lam) * loss_disc

            # Optional distribution-matching penalty on per-feature quantiles.
            q_weight = float(config.get("quantile_loss_weight", 0.0))
            if q_weight > 0:
                q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
                q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
                x_real = x_cont_resid
                if cont_target == "x0":
                    x_gen = eps_pred
                else:
                    # Reconstruct x0 from the eps prediction.
                    x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
                x_real = x_real.view(-1, x_real.size(-1))
                x_gen = x_gen.view(-1, x_gen.size(-1))
                q_real = torch.quantile(x_real, q_tensor, dim=0)
                q_gen = torch.quantile(x_gen, q_tensor, dim=0)
                quantile_loss = torch.mean(torch.abs(q_gen - q_real))
                loss = loss + q_weight * quantile_loss

            # Optional mean/std matching penalty on the residuals.
            stat_weight = float(config.get("residual_stat_weight", 0.0))
            if stat_weight > 0:
                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
                if cont_target == "x0":
                    x_gen = eps_pred
                else:
                    x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
                x_real = x_cont_resid
                mean_real = x_real.mean(dim=(0, 1))
                mean_gen = x_gen.mean(dim=(0, 1))
                std_real = x_real.std(dim=(0, 1))
                std_gen = x_gen.std(dim=(0, 1))
                stat_loss = F.mse_loss(mean_gen, mean_real) + F.mse_loss(std_gen, std_real)
                loss = loss + stat_weight * stat_loss

            opt.zero_grad()
            loss.backward()
            if float(config.get("grad_clip", 0.0)) > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
            opt.step()
            if ema is not None:
                ema.update(model)

            if step % int(config["log_every"]) == 0:
                print("epoch", epoch, "step", step, "loss", float(loss))
                with open(log_path, "a", encoding="utf-8") as f:
                    f.write(
                        "%d,%d,%.6f,%.6f,%.6f\n"
                        % (epoch, step, float(loss), float(loss_cont), float(loss_disc))
                    )

            total_step += 1
            if total_step % int(config["ckpt_every"]) == 0:
                ckpt = {
                    "model": model.state_dict(),
                    "optim": opt.state_dict(),
                    "config": config,
                    "step": total_step,
                    "epoch": epoch,
                    "step_in_epoch": step + 1,
                    "temporal_done": temporal_done,
                    "temporal_step": temporal_total_step,
                }
                if ema is not None:
                    ckpt["ema"] = ema.state_dict()
                if temporal_model is not None:
                    ckpt["temporal"] = temporal_model.state_dict()
                if opt_temporal is not None:
                    ckpt["temporal_optim"] = opt_temporal.state_dict()
                atomic_torch_save(ckpt, main_ckpt_path)

        main_start_step = 0
        # End-of-epoch checkpoint: resume point is the next epoch, step 0.
        ckpt = {
            "model": model.state_dict(),
            "optim": opt.state_dict(),
            "config": config,
            "step": total_step,
            "epoch": epoch + 1,
            "step_in_epoch": 0,
            "temporal_done": temporal_done,
            "temporal_step": temporal_total_step,
        }
        if ema is not None:
            ckpt["ema"] = ema.state_dict()
        if temporal_model is not None:
            ckpt["temporal"] = temporal_model.state_dict()
        if opt_temporal is not None:
            ckpt["temporal_optim"] = opt_temporal.state_dict()
        atomic_torch_save(ckpt, main_ckpt_path)

    # Final exports: raw weights (and EMA / temporal weights when enabled).
    atomic_torch_save(model.state_dict(), model_path)
    if ema is not None:
        atomic_torch_save(ema.state_dict(), ema_path)
    if temporal_model is not None:
        atomic_torch_save(temporal_model.state_dict(), temporal_path)


if __name__ == "__main__":
    main()