Add resumable submission pipeline

This commit is contained in:
MZ YANG
2026-04-18 19:01:25 +08:00
parent b8696d0c54
commit fb3f82006a
6 changed files with 1526 additions and 0 deletions

View File

@@ -0,0 +1,81 @@
{
"data_path": "../../dataset/hai/hai-21.03/train1.csv.gz",
"data_glob": "../../dataset/hai/hai-21.03/train*.csv.gz",
"split_path": "./feature_split.json",
"stats_path": "./results/cont_stats.json",
"vocab_path": "./results/disc_vocab.json",
"out_dir": "./results",
"device": "auto",
"timesteps": 600,
"batch_size": 12,
"seq_len": 96,
"epochs": 10,
"max_batches": 4000,
"lambda": 0.7,
"lr": 0.0005,
"seed": 1337,
"log_every": 10,
"ckpt_every": 50,
"ema_decay": 0.999,
"use_ema": true,
"clip_k": 5.0,
"grad_clip": 1.0,
"use_condition": true,
"condition_type": "file_id",
"cond_dim": 32,
"use_tanh_eps": false,
"eps_scale": 1.0,
"model_time_dim": 128,
"model_hidden_dim": 512,
"model_num_layers": 2,
"model_dropout": 0.1,
"model_ff_mult": 2,
"model_pos_dim": 64,
"model_use_pos_embed": true,
"backbone_type": "transformer",
"transformer_num_layers": 3,
"transformer_nhead": 4,
"transformer_ff_dim": 512,
"transformer_dropout": 0.1,
"disc_mask_scale": 0.9,
"cont_loss_weighting": "inv_std",
"cont_loss_eps": 1e-6,
"cont_target": "x0",
"cont_clamp_x0": 5.0,
"use_quantile_transform": true,
"quantile_bins": 1001,
"cont_bound_mode": "none",
"cont_bound_strength": 2.0,
"cont_post_calibrate": true,
"cont_post_scale": {},
"full_stats": true,
"type1_features": ["P1_B4002", "P2_MSD", "P4_HT_LD", "P1_B2004", "P1_B3004", "P1_B4022", "P1_B3005"],
"type2_features": ["P1_B4005"],
"type3_features": ["P1_PCV02Z", "P1_PCV01Z", "P1_PCV01D", "P1_FCV02Z", "P1_FCV03D", "P1_FCV03Z", "P1_LCV01D", "P1_LCV01Z"],
"type4_features": ["P1_PIT02", "P2_SIT02", "P1_FT03"],
"type5_features": ["P1_FT03Z", "P1_FT02Z"],
"type6_features": ["P4_HT_PO", "P2_24Vdc", "P2_HILout"],
"routing_type1_features": ["P1_B4022"],
"routing_type5_features": [],
"shuffle_buffer": 256,
"use_temporal_stage1": true,
"temporal_backbone": "transformer",
"temporal_hidden_dim": 256,
"temporal_num_layers": 1,
"temporal_dropout": 0.0,
"temporal_pos_dim": 64,
"temporal_use_pos_embed": true,
"temporal_transformer_num_layers": 2,
"temporal_transformer_nhead": 4,
"temporal_transformer_ff_dim": 256,
"temporal_transformer_dropout": 0.1,
"temporal_epochs": 3,
"temporal_lr": 0.001,
"quantile_loss_weight": 0.2,
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
"snr_weighted_loss": true,
"snr_gamma": 1.0,
"residual_stat_weight": 0.05,
"sample_batch_size": 4,
"sample_seq_len": 96
}

View File

@@ -0,0 +1,440 @@
#!/usr/bin/env python3
"""Sample from a trained hybrid diffusion model with routing-aware export fixes."""
from __future__ import annotations
import argparse
import csv
import gzip
import json
import os
from pathlib import Path
import torch
import torch.nn.functional as F
from data_utils import inverse_quantile_transform, load_split, normalize_cont, quantile_calibrate_to_real
from export_samples import build_inverse_vocab, load_stats, load_torch_state, read_header
from hybrid_diffusion import (
HybridDiffusionModel,
TemporalGRUGenerator,
TemporalTransformerGenerator,
cosine_beta_schedule,
)
from platform_utils import resolve_device, resolve_path
from submission_type_utils import (
denormalize_cont_tensor,
resolve_routing_features,
resolve_taxonomy_features,
)
def parse_args():
    """Command-line options for sampling and exporting generated HAI sequences.

    Default data/model paths are anchored on this script's directory so the
    tool works regardless of the caller's working directory.
    """
    here = Path(__file__).resolve().parent
    repo_root = here.parent.parent
    data_dir = repo_root / "dataset" / "hai" / "hai-21.03"
    ap = argparse.ArgumentParser(description="Sample and export HAI feature sequences with routing-aware fixes.")
    ap.add_argument("--data-path", default=str(data_dir / "train1.csv.gz"))
    ap.add_argument("--data-glob", default=str(data_dir / "train*.csv.gz"))
    ap.add_argument("--split-path", default=str(here / "feature_split.json"))
    ap.add_argument("--stats-path", default=str(here / "results" / "cont_stats.json"))
    ap.add_argument("--vocab-path", default=str(here / "results" / "disc_vocab.json"))
    ap.add_argument("--model-path", default=str(here / "results" / "model.pt"))
    ap.add_argument("--out", default=str(here / "results" / "generated.csv"))
    ap.add_argument("--timesteps", type=int, default=200)
    ap.add_argument("--seq-len", type=int, default=64)
    ap.add_argument("--batch-size", type=int, default=2)
    ap.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    ap.add_argument("--include-time", action="store_true", help="Include time column as a simple index")
    ap.add_argument("--clip-k", type=float, default=5.0, help="Clip continuous values to mean±k*std")
    ap.add_argument("--use-ema", action="store_true", help="Use EMA weights if available")
    ap.add_argument("--config", default=None, help="Optional config_used.json to infer conditioning")
    ap.add_argument("--condition-id", type=int, default=-1, help="Condition file id (0..N-1), -1=random")
    ap.add_argument("--include-condition", action="store_true", help="Include condition id column in CSV")
    return ap.parse_args()
def main():
    """Generate sequences from the trained hybrid diffusion model and export them as CSV.

    Pipeline: resolve paths -> load split/stats/vocab -> rebuild model from the
    training config -> (optional) stage-1 temporal trend -> reverse diffusion ->
    denormalize / bound / calibrate -> route type1/type5 columns -> write CSV.
    """
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    # Resolve every user-supplied path relative to this script's directory.
    args.data_path = str(resolve_path(base_dir, args.data_path))
    args.data_glob = str(resolve_path(base_dir, args.data_glob)) if args.data_glob else ""
    args.split_path = str(resolve_path(base_dir, args.split_path))
    args.stats_path = str(resolve_path(base_dir, args.stats_path))
    args.vocab_path = str(resolve_path(base_dir, args.vocab_path))
    args.model_path = str(resolve_path(base_dir, args.model_path))
    args.out = str(resolve_path(base_dir, args.out))
    if not os.path.exists(args.model_path):
        raise SystemExit(f"missing model file: {args.model_path}")
    # Prefer the first file matching --data-glob as the header/reference source.
    data_path = args.data_path
    if args.data_glob:
        base = Path(args.data_glob).parent
        pat = Path(args.data_glob).name
        matches = sorted(base.glob(pat))
        if matches:
            data_path = str(matches[0])
    # Feature split: time and attack-label columns are never modeled.
    split = load_split(args.split_path)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
    # Per-column statistics produced by the prepare stage.
    stats = load_stats(args.stats_path)
    mean = stats["mean"]
    std = stats["std"]
    vmin = stats.get("min", {})
    vmax = stats.get("max", {})
    int_like = stats.get("int_like", {})
    max_decimals = stats.get("max_decimals", {})
    transforms = stats.get("transform", {})
    quantile_probs = stats.get("quantile_probs")
    quantile_values = stats.get("quantile_values")
    quantile_raw_values = stats.get("quantile_raw_values")
    vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
    vocab = vocab_json["vocab"]
    top_token = vocab_json.get("top_token", {})
    inv_vocab = build_inverse_vocab(vocab)
    vocab_sizes = [len(vocab[c]) for c in disc_cols]
    device = resolve_device(args.device)
    # Optionally read the training config to recover conditioning and model hyperparameters.
    cfg = {}
    use_condition = False
    cond_vocab_size = 0
    if args.config:
        args.config = str(resolve_path(base_dir, args.config))
    if args.config and os.path.exists(args.config):
        with open(args.config, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
        if use_condition:
            # Condition vocabulary size = number of training files matched by the glob.
            cfg_base = Path(args.config).resolve().parent
            cfg_glob = cfg.get("data_glob", args.data_glob)
            cfg_glob = str(resolve_path(cfg_base, cfg_glob))
            base = Path(cfg_glob).parent
            pat = Path(cfg_glob).name
            cond_vocab_size = len(sorted(base.glob(pat)))
            if cond_vocab_size <= 0:
                raise SystemExit(f"use_condition enabled but no files matched data_glob: {cfg_glob}")
    cont_target = str(cfg.get("cont_target", "eps"))
    cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
    use_quantile = bool(cfg.get("use_quantile_transform", False))
    cont_bound_mode = str(cfg.get("cont_bound_mode", "clamp"))
    cont_bound_strength = float(cfg.get("cont_bound_strength", 1.0))
    cont_post_scale = cfg.get("cont_post_scale", {}) if isinstance(cfg.get("cont_post_scale", {}), dict) else {}
    cont_post_calibrate = bool(cfg.get("cont_post_calibrate", False))
    # Routing: type1 columns are fed to the model as continuous conditioning and
    # copied from real data at export; type5 "...Z" columns mirror a base column.
    # Neither group is generated by the diffusion backbone itself.
    route_type1_cols = resolve_routing_features(cfg, cont_cols, "type1_features")
    route_type5_cols = resolve_routing_features(cfg, cont_cols, "type5_features")
    type4_cols = resolve_taxonomy_features(cfg, cont_cols, "type4_features")
    model_cont_cols = [c for c in cont_cols if c not in route_type1_cols and c not in route_type5_cols]
    use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
    temporal_use_type1_cond = bool(cfg.get("temporal_use_type1_cond", False))
    temporal_focus_type4 = bool(cfg.get("temporal_focus_type4", False))
    temporal_exclude_type4 = bool(cfg.get("temporal_exclude_type4", False))
    temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
    temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
    temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
    temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
    temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64))
    temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True))
    temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2))
    temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4))
    temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512))
    temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1))
    backbone_type = str(cfg.get("backbone_type", "gru"))
    transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
    transformer_nhead = int(cfg.get("transformer_nhead", 4))
    transformer_ff_dim = int(cfg.get("transformer_ff_dim", 512))
    transformer_dropout = float(cfg.get("transformer_dropout", 0.1))
    # Rebuild the diffusion model with the same hyperparameters used at training time.
    model = HybridDiffusionModel(
        cont_dim=len(model_cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=int(cfg.get("model_time_dim", 64)),
        hidden_dim=int(cfg.get("model_hidden_dim", 256)),
        num_layers=int(cfg.get("model_num_layers", 1)),
        dropout=float(cfg.get("model_dropout", 0.0)),
        ff_mult=int(cfg.get("model_ff_mult", 2)),
        pos_dim=int(cfg.get("model_pos_dim", 64)),
        use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
        backbone_type=backbone_type,
        transformer_num_layers=transformer_num_layers,
        transformer_nhead=transformer_nhead,
        transformer_ff_dim=transformer_ff_dim,
        transformer_dropout=transformer_dropout,
        cond_cont_dim=len(route_type1_cols),
        cond_vocab_size=cond_vocab_size if use_condition else 0,
        cond_dim=int(cfg.get("cond_dim", 32)),
        use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
        eps_scale=float(cfg.get("eps_scale", 1.0)),
    ).to(device)
    # Prefer EMA weights when requested and the sibling model_ema.pt exists.
    if args.use_ema and os.path.exists(args.model_path.replace("model.pt", "model_ema.pt")):
        ema_path = args.model_path.replace("model.pt", "model_ema.pt")
        model.load_state_dict(load_torch_state(ema_path, device))
    else:
        model.load_state_dict(load_torch_state(args.model_path, device))
    model.eval()
    # Optional stage-1 temporal trend generator; its conditioning width is
    # recovered from the checkpoint's input projection shape when possible.
    temporal_model = None
    if use_temporal_stage1:
        temporal_path = Path(args.model_path).with_name("temporal.pt")
        if not temporal_path.exists():
            raise SystemExit(f"missing temporal model file: {temporal_path}")
        temporal_state = load_torch_state(str(temporal_path), device)
        temporal_cond_dim = len(route_type1_cols) if (temporal_use_type1_cond and route_type1_cols) else 0
        if isinstance(temporal_state, dict):
            if "in_proj.weight" in temporal_state:
                try:
                    # cond width = input projection width minus the modeled columns.
                    temporal_cond_dim = max(0, int(temporal_state["in_proj.weight"].shape[1]) - len(model_cont_cols))
                except Exception:
                    pass
            elif "gru.weight_ih_l0" in temporal_state:
                try:
                    temporal_cond_dim = max(0, int(temporal_state["gru.weight_ih_l0"].shape[1]) - len(model_cont_cols))
                except Exception:
                    pass
        if temporal_backbone == "transformer":
            temporal_model = TemporalTransformerGenerator(
                input_dim=len(model_cont_cols),
                hidden_dim=temporal_hidden_dim,
                num_layers=temporal_transformer_num_layers,
                nhead=temporal_transformer_nhead,
                ff_dim=temporal_transformer_ff_dim,
                dropout=temporal_transformer_dropout,
                pos_dim=temporal_pos_dim,
                use_pos_embed=temporal_use_pos_embed,
                cond_dim=temporal_cond_dim,
            ).to(device)
        else:
            temporal_model = TemporalGRUGenerator(
                input_dim=len(model_cont_cols),
                hidden_dim=temporal_hidden_dim,
                num_layers=temporal_num_layers,
                dropout=temporal_dropout,
                cond_dim=temporal_cond_dim,
            ).to(device)
        temporal_model.load_state_dict(temporal_state)
        temporal_model.eval()
    # Reverse diffusion setup: continuous channels start as Gaussian noise,
    # discrete channels start fully masked (mask token id == vocab size).
    betas = cosine_beta_schedule(args.timesteps).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    x_cont = torch.randn(args.batch_size, args.seq_len, len(model_cont_cols), device=device)
    x_disc = torch.full((args.batch_size, args.seq_len, len(disc_cols)), 0, device=device, dtype=torch.long)
    mask_tokens = torch.tensor(vocab_sizes, device=device)
    for i in range(len(disc_cols)):
        x_disc[:, :, i] = mask_tokens[i]
    cond = None
    if use_condition:
        # Redundant with the config-load guard above, kept as a defensive check.
        if cond_vocab_size <= 0:
            raise SystemExit("use_condition enabled but no files matched data_glob")
        if args.condition_id < 0:
            cond_id = torch.randint(0, cond_vocab_size, (args.batch_size,), device=device)
        else:
            cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
        cond = cond_id
    # Continuous conditioning: the first seq_len rows of the first reference file,
    # restricted to routed type1 columns and normalized like training data.
    cond_cont = None
    if route_type1_cols:
        ref_glob = cfg.get("data_glob") or args.data_glob
        if ref_glob:
            ref_glob = str(resolve_path(Path(args.config).parent, ref_glob)) if args.config else ref_glob
            base = Path(ref_glob).parent
            pat = Path(ref_glob).name
            refs = sorted(base.glob(pat))
            if refs:
                ref_path = refs[0]
                ref_rows = []
                with gzip.open(ref_path, "rt", newline="") as fh:
                    reader = csv.DictReader(fh)
                    for row in reader:
                        ref_rows.append(row)
                if len(ref_rows) >= args.seq_len:
                    seq = ref_rows[: args.seq_len]
                    cond_cont = torch.zeros(args.batch_size, args.seq_len, len(route_type1_cols), device=device)
                    for t, row in enumerate(seq):
                        for i, c in enumerate(route_type1_cols):
                            cond_cont[:, t, i] = float(row[c])
                    cond_cont = normalize_cont(
                        cond_cont,
                        route_type1_cols,
                        mean,
                        std,
                        transforms=transforms,
                        quantile_probs=quantile_probs,
                        quantile_values=quantile_values,
                        use_quantile=use_quantile,
                    )
    # Stage-1 trend, optionally restricted to (or excluding) type4 columns.
    trend = None
    if temporal_model is not None:
        trend = temporal_model.generate(args.batch_size, args.seq_len, device, cond_cont=cond_cont)
        if temporal_focus_type4 and type4_cols:
            type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols]
            if type4_model_idx:
                trend_mask = torch.zeros(1, 1, len(model_cont_cols), device=device, dtype=trend.dtype)
                trend_mask[:, :, type4_model_idx] = 1.0
                trend = trend * trend_mask
        elif temporal_exclude_type4 and type4_cols:
            type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols]
            if type4_model_idx:
                trend_mask = torch.ones(1, 1, len(model_cont_cols), device=device, dtype=trend.dtype)
                trend_mask[:, :, type4_model_idx] = 0.0
                trend = trend * trend_mask
    # DDPM ancestral sampling loop (t = T-1 .. 0).
    for t in reversed(range(args.timesteps)):
        t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
        eps_pred, logits = model(x_cont, x_disc, t_batch, cond, cond_cont=cond_cont)
        a_t = alphas[t]
        a_bar_t = alphas_cumprod[t]
        if cont_target == "x0":
            # Model predicts x0 directly; convert to an epsilon estimate for the DDPM update.
            x0_pred = eps_pred
            if cont_clamp_x0 > 0:
                x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
            eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
        coef1 = 1.0 / torch.sqrt(a_t)
        coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
        mean_x = coef1 * (x_cont - coef2 * eps_pred)
        if t > 0:
            noise = torch.randn_like(x_cont)
            x_cont = mean_x + torch.sqrt(betas[t]) * noise
        else:
            x_cont = mean_x
        if args.clip_k > 0:
            x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
        # Discrete channels: progressively unmask by sampling; final step uses argmax.
        for i, logit in enumerate(logits):
            if t == 0:
                probs = F.softmax(logit, dim=-1)
                x_disc[:, :, i] = torch.argmax(probs, dim=-1)
            else:
                mask = x_disc[:, :, i] == mask_tokens[i]
                if mask.any():
                    probs = F.softmax(logit, dim=-1)
                    sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(
                        args.batch_size, args.seq_len
                    )
                    x_disc[:, :, i][mask] = sampled[mask]
    if trend is not None:
        # Diffusion modeled the residual around the stage-1 trend; add it back once.
        x_cont = x_cont + trend
    x_cont = x_cont.cpu()
    x_disc = x_disc.cpu()
    if args.clip_k > 0:
        x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
    # Undo normalization back to data scale (quantile inverse or z-score).
    if use_quantile:
        q_vals = {c: quantile_values[c] for c in model_cont_cols}
        x_cont = inverse_quantile_transform(x_cont, model_cont_cols, quantile_probs, q_vals)
    else:
        mean_vec = torch.tensor([mean[c] for c in model_cont_cols], dtype=x_cont.dtype)
        std_vec = torch.tensor([std[c] for c in model_cont_cols], dtype=x_cont.dtype)
        x_cont = x_cont * std_vec + mean_vec
    # NOTE(review): expm1 is applied after either branch above; this assumes
    # log1p was applied before the quantile fit too — confirm against prepare_data.
    for i, c in enumerate(model_cont_cols):
        if transforms.get(c) == "log1p":
            x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
    if cont_post_calibrate and quantile_raw_values and quantile_probs:
        # Match the marginal distribution of real (raw-scale) data via quantile mapping.
        q_raw = {c: quantile_raw_values[c] for c in model_cont_cols}
        x_cont = quantile_calibrate_to_real(x_cont, model_cont_cols, quantile_probs, q_raw)
    # Optionally bound each column to its observed [min, max].
    if vmin and vmax:
        for i, c in enumerate(model_cont_cols):
            lo = vmin.get(c, None)
            hi = vmax.get(c, None)
            if lo is None or hi is None:
                continue
            lo = float(lo)
            hi = float(hi)
            if cont_bound_mode == "none":
                continue
            if cont_bound_mode == "sigmoid":
                x_cont[:, :, i] = lo + (hi - lo) * torch.sigmoid(x_cont[:, :, i])
            elif cont_bound_mode == "soft_tanh":
                mid = 0.5 * (lo + hi)
                half = 0.5 * (hi - lo)
                denom = cont_bound_strength if cont_bound_strength > 0 else 1.0
                x_cont[:, :, i] = mid + half * torch.tanh(x_cont[:, :, i] / denom)
            else:
                # Default bound mode: hard clamp.
                x_cont[:, :, i] = torch.clamp(x_cont[:, :, i], lo, hi)
    if cont_post_scale:
        # Per-column multiplicative fudge factors from config; bad values fall back to 1.0.
        for i, c in enumerate(model_cont_cols):
            if c in cont_post_scale:
                try:
                    scale = float(cont_post_scale[c])
                except Exception:
                    scale = 1.0
                x_cont[:, :, i] = x_cont[:, :, i] * scale
    # Reassemble the full continuous matrix: modeled columns + routed columns.
    full_cont = torch.zeros(args.batch_size, args.seq_len, len(cont_cols), dtype=x_cont.dtype)
    for i, c in enumerate(model_cont_cols):
        full_idx = cont_cols.index(c)
        full_cont[:, :, full_idx] = x_cont[:, :, i]
    if cond_cont is not None and route_type1_cols:
        # Type1 routed columns are exported as the denormalized conditioning input.
        cond_denorm = denormalize_cont_tensor(
            cond_cont.cpu(),
            route_type1_cols,
            mean,
            std,
            transforms=transforms,
            quantile_probs=quantile_probs,
            quantile_values=quantile_values,
            use_quantile=use_quantile,
        )
        for i, c in enumerate(route_type1_cols):
            full_idx = cont_cols.index(c)
            full_cont[:, :, full_idx] = cond_denorm[:, :, i]
    # Type5 "...Z" columns mirror their base column (e.g. P1_FT03Z <- P1_FT03).
    for c in route_type5_cols:
        if c.endswith("Z"):
            base = c[:-1]
            if base in cont_cols:
                bidx = cont_cols.index(base)
                cidx = cont_cols.index(c)
                full_cont[:, :, cidx] = full_cont[:, :, bidx]
    # Write CSV in the original header order, formatting numbers like the source data.
    header = read_header(data_path)
    out_cols = [c for c in header if c != time_col or args.include_time]
    if args.include_condition and use_condition:
        out_cols = ["__cond_file_id"] + out_cols
    os.makedirs(os.path.dirname(args.out), exist_ok=True)
    with open(args.out, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=out_cols)
        writer.writeheader()
        row_index = 0
        for b in range(args.batch_size):
            for t in range(args.seq_len):
                row = {}
                if args.include_condition and use_condition:
                    row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1"
                if args.include_time and time_col in header:
                    row[time_col] = str(row_index)
                for i, c in enumerate(cont_cols):
                    val = float(full_cont[b, t, i])
                    if int_like.get(c, False):
                        row[c] = str(int(round(val)))
                    else:
                        # Reproduce the source data's decimal precision per column.
                        dec = int(max_decimals.get(c, 6))
                        fmt = ("%%.%df" % dec) if dec > 0 else "%.0f"
                        row[c] = fmt % val
                for i, c in enumerate(disc_cols):
                    tok_idx = int(x_disc[b, t, i])
                    tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "<UNK>"
                    if tok == "<UNK>" and c in top_token:
                        # Unknown tokens fall back to the column's most frequent token.
                        tok = top_token[c]
                    row[c] = tok
                writer.writerow(row)
                row_index += 1
# Script entry point; allows importing this module without side effects.
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,19 @@
#!/usr/bin/env bash
# One-shot launcher for the resumable submission pipeline.
# RUN_DIR and DEVICE may be overridden via environment variables; any extra
# CLI arguments are forwarded unchanged to run_submission_resume.py.
set -euo pipefail
# Directory containing this script, independent of the caller's CWD.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
RUN_DIR="${RUN_DIR:-$SCRIPT_DIR/results/submission_full}"
LOG_DIR="$RUN_DIR/logs"
mkdir -p "$LOG_DIR"
# Timestamped log file so repeated runs never clobber each other.
STAMP="$(date '+%Y%m%d-%H%M%S')"
LOG_FILE="$LOG_DIR/pipeline-$STAMP.log"
echo "[run_submission_full] run_dir=$RUN_DIR"
echo "[run_submission_full] log_file=$LOG_FILE"
# pipefail (set above) makes the python exit status win over tee's.
python "$SCRIPT_DIR/run_submission_resume.py" \
--config "$SCRIPT_DIR/config_submission_full.json" \
--device "${DEVICE:-cuda}" \
--run-dir "$RUN_DIR" \
"$@" 2>&1 | tee -a "$LOG_FILE"

View File

@@ -0,0 +1,399 @@
#!/usr/bin/env python3
"""One-command full pipeline runner with safe resume and stage skipping."""
from __future__ import annotations
import argparse
import json
import subprocess
import sys
from pathlib import Path
from typing import Dict, List
from platform_utils import is_windows, safe_path
def run(cmd: List[str]) -> None:
    """Execute one pipeline stage command, raising CalledProcessError on failure.

    Every argument is passed through safe_path() first so platform-specific
    path quirks are normalized, and the echoed line shows the argv that is
    actually executed (previously the pre-conversion argv was printed).
    """
    cmd = [safe_path(arg) for arg in cmd]
    print("running:", " ".join(cmd))
    # shell=False on every platform: cmd is an argv list, so no shell parsing
    # is wanted. The former is_windows() branch was byte-identical to this
    # call (shell=False is subprocess.run's default) and has been removed.
    subprocess.run(cmd, check=True, shell=False)
def parse_args():
    """Command-line options for the resumable pipeline runner."""
    here = Path(__file__).resolve().parent
    ap = argparse.ArgumentParser(description="Run prepare -> train -> export -> eval with resume-aware staging.")
    ap.add_argument("--config", default=str(here / "config_submission_full.json"))
    ap.add_argument("--device", default="auto")
    ap.add_argument("--run-dir", default=str(here / "results" / "submission_full"))
    ap.add_argument("--reference", default="")
    ap.add_argument("--no-resume", action="store_true", help="Do not auto-skip completed stages or resume from ckpt.")
    # One boolean opt-out flag per pipeline stage.
    for flag in (
        "--skip-prepare",
        "--skip-train",
        "--skip-export",
        "--skip-eval",
        "--skip-comprehensive-eval",
        "--skip-postprocess",
        "--skip-post-eval",
        "--skip-diagnostics",
    ):
        ap.add_argument(flag, action="store_true")
    return ap.parse_args()
def load_state(path: Path) -> Dict[str, str]:
    """Load the persisted stage-state mapping from a JSON file.

    A missing, unreadable, or corrupt state file is treated as "no prior
    progress" and yields an empty mapping rather than raising, so a damaged
    state file can never block the pipeline from starting over.
    """
    if path.exists():
        try:
            raw = path.read_text(encoding="utf-8")
            return json.loads(raw)
        except Exception:
            pass
    return {}
def save_state(path: Path, state: Dict[str, str]) -> None:
    """Persist the stage-state mapping as pretty-printed, key-sorted JSON."""
    payload = json.dumps(state, indent=2, sort_keys=True)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(payload, encoding="utf-8")
def stage_complete(state: Dict[str, str], stage: str, outputs: List[Path], resume: bool) -> bool:
    """Decide whether a stage can be skipped.

    Only meaningful when resuming: a stage counts as complete if every one of
    its expected output files exists, or failing that, if the recorded state
    says it already finished.
    """
    if not resume:
        return False
    have_all_outputs = bool(outputs) and all(out.exists() for out in outputs)
    return have_all_outputs or state.get(stage) == "done"
def main():
    """Run prepare -> train -> export -> eval(s) -> diagnostics, skipping finished stages.

    Stage completion is tracked in <run_dir>/pipeline_state.json and by the
    existence of each stage's expected output files; every executed command is
    appended to <run_dir>/run_commands.txt for reproducibility.
    """
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    config_path = Path(args.config)
    if not config_path.is_absolute():
        config_path = (base_dir / config_path).resolve()
    run_dir = Path(args.run_dir)
    if not run_dir.is_absolute():
        run_dir = (base_dir / run_dir).resolve()
    run_dir.mkdir(parents=True, exist_ok=True)
    cfg = json.loads(config_path.read_text(encoding="utf-8"))
    cfg_base = config_path.parent

    def abs_cfg_like(value: str) -> str:
        # Make a config-relative path absolute. Glob patterns are joined
        # without resolve() so the wildcard component is preserved verbatim.
        p = Path(value)
        if p.is_absolute():
            return str(p)
        if any(ch in value for ch in ["*", "?", "["]):
            return str(cfg_base / p)
        return str((cfg_base / p).resolve())

    # Reference data for evaluation: CLI flag wins, then config glob, then path.
    ref = args.reference or cfg.get("data_glob") or cfg.get("data_path") or ""
    if ref:
        ref = abs_cfg_like(str(ref))
    timesteps = int(cfg.get("timesteps", 200))
    # Sampling dimensions fall back to the training dimensions when unset.
    seq_len = int(cfg.get("sample_seq_len", cfg.get("seq_len", 64)))
    batch_size = int(cfg.get("sample_batch_size", cfg.get("batch_size", 2)))
    clip_k = float(cfg.get("clip_k", 5.0))
    split_path = abs_cfg_like(str(cfg.get("split_path", "./feature_split.json")))
    stats_path = abs_cfg_like(str(cfg.get("stats_path", "./results/cont_stats.json")))
    vocab_path = abs_cfg_like(str(cfg.get("vocab_path", "./results/disc_vocab.json")))
    data_path = abs_cfg_like(str(cfg.get("data_path", ""))) if cfg.get("data_path") else ""
    data_glob = abs_cfg_like(str(cfg.get("data_glob", ""))) if cfg.get("data_glob") else ""
    state_path = run_dir / "pipeline_state.json"
    state = load_state(state_path)
    resume = not args.no_resume
    # Training writes config_used.json into run_dir; later stages prefer it
    # over the original config when it exists.
    cfg_for_steps = run_dir / "config_used.json"
    # stage_defs entries: (stage name, expected output files, command argv).
    stage_defs = []
    if not args.skip_prepare:
        stage_defs.append(
            (
                "prepare",
                [Path(stats_path), Path(vocab_path)],
                [sys.executable, str(base_dir / "prepare_data.py"), "--config", str(config_path)],
            )
        )
    if not args.skip_train:
        train_cmd = [
            sys.executable,
            str(base_dir / "train_resume.py"),
            "--config",
            str(config_path),
            "--device",
            args.device,
            "--out-dir",
            str(run_dir),
            "--seed",
            str(int(cfg.get("seed", 1337))),
        ]
        if resume:
            train_cmd.append("--resume")
        stage_defs.append(("train", [run_dir / "model.pt"], train_cmd))
    if not args.skip_export:
        stage_defs.append(
            (
                "export",
                [run_dir / "generated.csv"],
                [
                    sys.executable,
                    str(base_dir / "export_samples_resume.py"),
                    "--include-time",
                    "--device",
                    args.device,
                    "--config",
                    str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    "--data-path",
                    str(data_path),
                    "--data-glob",
                    str(data_glob),
                    "--split-path",
                    str(split_path),
                    "--stats-path",
                    str(stats_path),
                    "--vocab-path",
                    str(vocab_path),
                    "--model-path",
                    str(run_dir / "model.pt"),
                    "--out",
                    str(run_dir / "generated.csv"),
                    "--timesteps",
                    str(timesteps),
                    "--seq-len",
                    str(seq_len),
                    "--batch-size",
                    str(batch_size),
                    "--clip-k",
                    str(clip_k),
                    "--use-ema",
                ],
            )
        )
    if not args.skip_eval:
        eval_cmd = [
            sys.executable,
            str(base_dir / "evaluate_generated.py"),
            "--generated",
            str(run_dir / "generated.csv"),
            "--split",
            str(split_path),
            "--stats",
            str(stats_path),
            "--vocab",
            str(vocab_path),
            "--out",
            str(run_dir / "eval.json"),
        ]
        if ref:
            eval_cmd += ["--reference", str(ref)]
        stage_defs.append(("eval", [run_dir / "eval.json"], eval_cmd))
    if not args.skip_comprehensive_eval:
        # NOTE(review): --reference here is the config JSON path, not `ref`
        # (the data glob) — confirm evaluate_comprehensive expects that.
        stage_defs.append(
            (
                "comprehensive_eval",
                [run_dir / "comprehensive_eval.json"],
                [
                    sys.executable,
                    str(base_dir / "evaluate_comprehensive.py"),
                    "--generated",
                    str(run_dir / "generated.csv"),
                    "--reference",
                    str(config_path),
                    "--config",
                    str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    "--split",
                    str(split_path),
                    "--stats",
                    str(stats_path),
                    "--vocab",
                    str(vocab_path),
                    "--out",
                    str(run_dir / "comprehensive_eval.json"),
                    "--device",
                    args.device,
                ],
            )
        )
    if not args.skip_postprocess:
        post_cmd = [
            sys.executable,
            str(base_dir / "postprocess_types.py"),
            "--generated",
            str(run_dir / "generated.csv"),
            "--config",
            str(cfg_for_steps if cfg_for_steps.exists() else config_path),
            "--out",
            str(run_dir / "generated_post.csv"),
            "--seed",
            str(int(cfg.get("seed", 1337))),
        ]
        if ref:
            post_cmd += ["--reference", str(ref)]
        stage_defs.append(("postprocess", [run_dir / "generated_post.csv"], post_cmd))
    if not args.skip_post_eval:
        # Re-evaluate the post-processed CSV with the same evaluator.
        post_eval_cmd = [
            sys.executable,
            str(base_dir / "evaluate_generated.py"),
            "--generated",
            str(run_dir / "generated_post.csv"),
            "--split",
            str(split_path),
            "--stats",
            str(stats_path),
            "--vocab",
            str(vocab_path),
            "--out",
            str(run_dir / "eval_post.json"),
        ]
        if ref:
            post_eval_cmd += ["--reference", str(ref)]
        stage_defs.append(("post_eval", [run_dir / "eval_post.json"], post_eval_cmd))
    if not args.skip_comprehensive_eval:
        # Comprehensive evaluation of the post-processed output; gated by the
        # same flag as the pre-postprocess comprehensive eval.
        stage_defs.append(
            (
                "comprehensive_post_eval",
                [run_dir / "comprehensive_eval_post.json"],
                [
                    sys.executable,
                    str(base_dir / "evaluate_comprehensive.py"),
                    "--generated",
                    str(run_dir / "generated_post.csv"),
                    "--reference",
                    str(config_path),
                    "--config",
                    str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    "--split",
                    str(split_path),
                    "--stats",
                    str(stats_path),
                    "--vocab",
                    str(vocab_path),
                    "--out",
                    str(run_dir / "comprehensive_eval_post.json"),
                    "--device",
                    args.device,
                ],
            )
        )
    if not args.skip_diagnostics:
        # Diagnostic reports; the *_stats scripts also receive the config JSON
        # as --reference (see NOTE above).
        stage_defs.extend(
            [
                (
                    "filtered_metrics",
                    [run_dir / "filtered_metrics.json"],
                    [
                        sys.executable,
                        str(base_dir / "filtered_metrics.py"),
                        "--eval",
                        str(run_dir / "eval.json"),
                        "--out",
                        str(run_dir / "filtered_metrics.json"),
                    ],
                ),
                (
                    "ranked_ks",
                    [run_dir / "ranked_ks.csv"],
                    [
                        sys.executable,
                        str(base_dir / "ranked_ks.py"),
                        "--eval",
                        str(run_dir / "eval.json"),
                        "--out",
                        str(run_dir / "ranked_ks.csv"),
                    ],
                ),
                (
                    "program_stats",
                    [run_dir / "program_stats.json"],
                    [
                        sys.executable,
                        str(base_dir / "program_stats.py"),
                        "--generated",
                        str(run_dir / "generated.csv"),
                        "--reference",
                        str(config_path),
                        "--config",
                        str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    ],
                ),
                (
                    "controller_stats",
                    [run_dir / "controller_stats.json"],
                    [
                        sys.executable,
                        str(base_dir / "controller_stats.py"),
                        "--generated",
                        str(run_dir / "generated.csv"),
                        "--reference",
                        str(config_path),
                        "--config",
                        str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    ],
                ),
                (
                    "actuator_stats",
                    [run_dir / "actuator_stats.json"],
                    [
                        sys.executable,
                        str(base_dir / "actuator_stats.py"),
                        "--generated",
                        str(run_dir / "generated.csv"),
                        "--reference",
                        str(config_path),
                        "--config",
                        str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    ],
                ),
                (
                    "pv_stats",
                    [run_dir / "pv_stats.json"],
                    [
                        sys.executable,
                        str(base_dir / "pv_stats.py"),
                        "--generated",
                        str(run_dir / "generated.csv"),
                        "--reference",
                        str(config_path),
                        "--config",
                        str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    ],
                ),
                (
                    "aux_stats",
                    [run_dir / "aux_stats.json"],
                    [
                        sys.executable,
                        str(base_dir / "aux_stats.py"),
                        "--generated",
                        str(run_dir / "generated.csv"),
                        "--reference",
                        str(config_path),
                        "--config",
                        str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    ],
                ),
            ]
        )
    # Append-only audit log of every command the pipeline launches.
    command_log = run_dir / "run_commands.txt"
    if not command_log.exists():
        command_log.write_text("", encoding="utf-8")
    for stage, outputs, cmd in stage_defs:
        if stage_complete(state, stage, outputs, resume):
            print(f"skip_stage {stage}: outputs already present")
            state[stage] = "done"
            save_state(state_path, state)
            continue
        # Mark "running" before launch so a crash leaves evidence of where it died.
        state[stage] = "running"
        save_state(state_path, state)
        with command_log.open("a", encoding="utf-8") as fh:
            fh.write(stage + ": " + " ".join(cmd) + "\n")
        run(cmd)
        state[stage] = "done"
        save_state(state_path, state)
    print(f"pipeline_complete run_dir={run_dir}")
# Script entry point; allows importing this module without side effects.
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,53 @@
#!/usr/bin/env python3
"""Helpers for keeping type taxonomy and routing policy separate."""
from __future__ import annotations
from typing import Dict, List, Optional
import torch
from data_utils import inverse_quantile_transform
def resolve_taxonomy_features(config: Dict, cont_cols: List[str], base_key: str) -> List[str]:
    """Return the taxonomy feature list under ``base_key``, filtered to known columns.

    A missing key or a null value both yield an empty list; the configured
    order of the surviving features is preserved.
    """
    wanted = config.get(base_key) or []
    known = set(cont_cols)
    return [name for name in wanted if name in known]
def resolve_routing_features(config: Dict, cont_cols: List[str], base_key: str) -> List[str]:
    """Return routing features for ``base_key``, filtered to known columns.

    The ``routing_``-prefixed key overrides the plain taxonomy key when
    present, so routing policy can diverge from the type taxonomy.
    """
    routed = config.get(f"routing_{base_key}", config.get(base_key, [])) or []
    known = set(cont_cols)
    return [name for name in routed if name in known]
def denormalize_cont_tensor(
    x: torch.Tensor,
    cont_cols: List[str],
    mean: Dict[str, float],
    std: Dict[str, float],
    transforms: Optional[Dict[str, str]] = None,
    quantile_probs: Optional[List[float]] = None,
    quantile_values: Optional[Dict[str, List[float]]] = None,
    use_quantile: bool = False,
) -> torch.Tensor:
    """Map a normalized continuous tensor back toward data scale.

    The input is indexed as ``x[:, :, i]`` per column, i.e. the last axis
    enumerates ``cont_cols`` (assumes a (batch, time, column) layout — confirm
    with callers). Inversion is either the quantile transform (when
    ``use_quantile``) or z-score (``x * std + mean``); afterwards any column
    marked ``log1p`` in ``transforms`` is mapped through expm1.

    Returns a new tensor; the input is never modified in place.
    """
    if x is None:
        raise ValueError("x must not be None")
    if not cont_cols:
        # Nothing to invert, but still return a copy for consistent ownership.
        return x.clone()
    out = x.clone()
    if use_quantile:
        if not quantile_probs or not quantile_values:
            raise ValueError("use_quantile=True but quantile stats are missing")
        q_vals = {c: quantile_values[c] for c in cont_cols}
        out = inverse_quantile_transform(out, cont_cols, quantile_probs, q_vals)
    else:
        mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=out.dtype, device=out.device)
        std_vec = torch.tensor([std[c] for c in cont_cols], dtype=out.dtype, device=out.device)
        out = out * std_vec + mean_vec
    # NOTE(review): expm1 runs after either branch — assumes log1p was applied
    # before the quantile fit as well; confirm against the prepare stage.
    if transforms:
        for i, c in enumerate(cont_cols):
            if transforms.get(c) == "log1p":
                out[:, :, i] = torch.expm1(out[:, :, i])
    return out

534
example/train_resume.py Normal file
View File

@@ -0,0 +1,534 @@
#!/usr/bin/env python3
"""Train hybrid diffusion with checkpoint resume and selective type-aware routing."""
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from typing import Dict
import torch
import torch.nn.functional as F
from data_utils import load_split, windowed_batches
from hybrid_diffusion import (
HybridDiffusionModel,
TemporalGRUGenerator,
TemporalTransformerGenerator,
cosine_beta_schedule,
q_sample_continuous,
q_sample_discrete,
)
from platform_utils import resolve_device, resolve_path, safe_path
from submission_type_utils import resolve_routing_features, resolve_taxonomy_features
from train import DEFAULTS, EMA, load_json, resolve_config_paths, set_seed
# Directory containing this script; used to resolve sibling project files.
BASE_DIR = Path(__file__).resolve().parent
def load_torch_state(path: str, device: str):
    """Load a checkpoint onto ``device``, preferring torch's safe weights-only mode.

    ``weights_only=True`` avoids arbitrary-code pickle loading; older torch
    versions that do not know the keyword raise TypeError, in which case we
    fall back to the legacy call.
    """
    try:
        state = torch.load(path, map_location=device, weights_only=True)
    except TypeError:
        # torch < 1.13 style API without the weights_only keyword.
        state = torch.load(path, map_location=device)
    return state
def atomic_torch_save(obj, path: Path) -> None:
    """Serialize *obj* to *path* atomically.

    Writes to a ``<name>.tmp`` sibling first, then renames it over the target
    so a crash mid-write never leaves a truncated checkpoint behind.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    scratch = target.with_suffix(target.suffix + ".tmp")
    torch.save(obj, str(scratch))
    # os.replace is atomic on POSIX and overwrites existing files on Windows.
    os.replace(str(scratch), str(target))
def parse_args():
    """Define and parse the command-line interface for resumable training."""
    cli = argparse.ArgumentParser(description="Train hybrid diffusion on HAI with resume support.")
    # Config / environment overrides.
    cli.add_argument("--config", default=None, help="Path to JSON config.")
    cli.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    cli.add_argument("--out-dir", default=None, help="Override output directory")
    cli.add_argument("--seed", type=int, default=None, help="Override random seed")
    # Phase / resume control.
    cli.add_argument("--temporal-only", action="store_true", help="Only train temporal stage-1 and exit.")
    cli.add_argument("--resume", action="store_true", help="Resume from checkpoint in out-dir if present.")
    cli.add_argument("--resume-ckpt", default=None, help="Optional explicit model checkpoint path.")
    return cli.parse_args()
def build_temporal_model(config: Dict, model_cont_cols, temporal_cond_dim: int, device: str):
    """Instantiate the stage-1 temporal generator chosen by ``temporal_backbone``.

    ``"transformer"`` selects the transformer variant; any other value falls
    back to the GRU generator.  Defaults mirror the config schema used by the
    rest of the training scripts.
    """
    backbone = str(config.get("temporal_backbone", "gru"))
    in_dim = len(model_cont_cols)
    hidden = int(config.get("temporal_hidden_dim", 256))
    if backbone != "transformer":
        return TemporalGRUGenerator(
            input_dim=in_dim,
            hidden_dim=hidden,
            num_layers=int(config.get("temporal_num_layers", 1)),
            dropout=float(config.get("temporal_dropout", 0.0)),
            cond_dim=temporal_cond_dim,
        ).to(device)
    return TemporalTransformerGenerator(
        input_dim=in_dim,
        hidden_dim=hidden,
        num_layers=int(config.get("temporal_transformer_num_layers", 2)),
        nhead=int(config.get("temporal_transformer_nhead", 4)),
        ff_dim=int(config.get("temporal_transformer_ff_dim", 512)),
        dropout=float(config.get("temporal_transformer_dropout", 0.1)),
        pos_dim=int(config.get("temporal_pos_dim", 64)),
        use_pos_embed=bool(config.get("temporal_use_pos_embed", True)),
        cond_dim=temporal_cond_dim,
    ).to(device)
def init_or_append_log(log_path: Path, resume: bool) -> None:
    """Ensure the CSV training log exists with its header row.

    When resuming onto an existing log the file is left untouched so prior
    rows survive; otherwise the file is (re)created with just the header.
    """
    if resume and log_path.exists():
        return  # keep accumulated rows; the training loop appends to them
    log_path.write_text("epoch,step,loss,loss_cont,loss_disc\n", encoding="utf-8")
def main():
    """Train the hybrid diffusion model on HAI data with checkpoint/resume support.

    Flow: merge CLI overrides into the JSON config, load the feature split,
    normalization stats and discrete vocab, build the models, optionally
    restore state from a checkpoint, run the optional temporal stage-1
    pre-training, then run the main diffusion training loop.  Checkpoints are
    written atomically every ``ckpt_every`` steps so an interrupted run can
    resume mid-epoch.
    """
    args = parse_args()
    if args.config:
        print("using_config", str(Path(args.config).resolve()))
    # Library defaults first; a JSON config (if given) overrides them, and any
    # relative paths inside it are resolved against the config file's directory.
    config = dict(DEFAULTS)
    if args.config:
        cfg_path = Path(args.config).resolve()
        config.update(load_json(str(cfg_path)))
        config = resolve_config_paths(config, cfg_path.parent)
    else:
        config = resolve_config_paths(config, BASE_DIR)
    # CLI flags take precedence over config values.
    if args.device != "auto":
        config["device"] = args.device
    if args.out_dir:
        out_dir = Path(args.out_dir)
        if not out_dir.is_absolute():
            base = Path(args.config).resolve().parent if args.config else BASE_DIR
            out_dir = resolve_path(base, out_dir)
        config["out_dir"] = str(out_dir)
    if args.seed is not None:
        config["seed"] = int(args.seed)
    if bool(args.temporal_only):
        # Force stage-1 on and zero out main epochs so only stage-1 runs.
        config["use_temporal_stage1"] = True
        config["epochs"] = 0
    set_seed(int(config["seed"]))
    # Feature split: continuous vs discrete columns, excluding the time column
    # and attack label columns.
    split = load_split(config["split_path"])
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
    # Type-aware routing: type1/type5 features are excluded from the diffusion
    # model's continuous inputs (type1 may re-enter as conditioning); type4
    # features can be focused on / excluded from the temporal trend below.
    type1_cols = resolve_routing_features(config, cont_cols, "type1_features")
    type5_cols = resolve_routing_features(config, cont_cols, "type5_features")
    type4_cols = resolve_taxonomy_features(config, cont_cols, "type4_features")
    model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
    if not model_cont_cols:
        raise SystemExit("model_cont_cols is empty; check routing_type1_features/routing_type5_features")
    # Normalization statistics and optional quantile transform tables.
    stats = load_json(config["stats_path"])
    mean = stats["mean"]
    std = stats["std"]
    transforms = stats.get("transform", {})
    raw_std = stats.get("raw_std", std)
    quantile_probs = stats.get("quantile_probs")
    quantile_values = stats.get("quantile_values")
    use_quantile = bool(config.get("use_quantile_transform", False))
    vocab = load_json(config["vocab_path"])["vocab"]
    vocab_sizes = [len(vocab[c]) for c in disc_cols]
    # Resolve the training files: glob if configured, else the single data_path.
    data_paths = None
    if "data_glob" in config and config["data_glob"]:
        data_paths = sorted(Path(config["data_glob"]).parent.glob(Path(config["data_glob"]).name))
        if data_paths:
            data_paths = [safe_path(p) for p in data_paths]
    if not data_paths:
        data_paths = [safe_path(config["data_path"])]
    # File-id conditioning: one embedding slot per training file.
    use_condition = bool(config.get("use_condition")) and config.get("condition_type") == "file_id"
    cond_vocab_size = len(data_paths) if use_condition else 0
    device = resolve_device(str(config["device"]))
    print("device", device)
    model = HybridDiffusionModel(
        cont_dim=len(model_cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=int(config.get("model_time_dim", 64)),
        hidden_dim=int(config.get("model_hidden_dim", 256)),
        num_layers=int(config.get("model_num_layers", 1)),
        dropout=float(config.get("model_dropout", 0.0)),
        ff_mult=int(config.get("model_ff_mult", 2)),
        pos_dim=int(config.get("model_pos_dim", 64)),
        use_pos_embed=bool(config.get("model_use_pos_embed", True)),
        backbone_type=str(config.get("backbone_type", "gru")),
        transformer_num_layers=int(config.get("transformer_num_layers", 4)),
        transformer_nhead=int(config.get("transformer_nhead", 8)),
        transformer_ff_dim=int(config.get("transformer_ff_dim", 2048)),
        transformer_dropout=float(config.get("transformer_dropout", 0.1)),
        cond_cont_dim=len(type1_cols),
        cond_vocab_size=cond_vocab_size,
        cond_dim=int(config.get("cond_dim", 32)),
        use_tanh_eps=bool(config.get("use_tanh_eps", False)),
        eps_scale=float(config.get("eps_scale", 1.0)),
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
    temporal_model = None
    opt_temporal = None
    # Stage-1 conditioning on type1 features is optional.
    temporal_use_type1_cond = bool(config.get("temporal_use_type1_cond", False))
    temporal_cond_dim = len(type1_cols) if (temporal_use_type1_cond and type1_cols) else 0
    temporal_focus_type4 = bool(config.get("temporal_focus_type4", False))
    temporal_exclude_type4 = bool(config.get("temporal_exclude_type4", False))
    type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols]
    # trend_mask selects which model columns the temporal trend applies to:
    # focus -> only type4 columns, exclude -> everything but type4.
    trend_mask = None
    if temporal_focus_type4 and type4_model_idx:
        trend_mask = torch.zeros(1, 1, len(model_cont_cols), device=device)
        trend_mask[:, :, type4_model_idx] = 1.0
    elif temporal_exclude_type4 and type4_model_idx:
        trend_mask = torch.ones(1, 1, len(model_cont_cols), device=device)
        trend_mask[:, :, type4_model_idx] = 0.0
    if bool(config.get("use_temporal_stage1", False)):
        temporal_model = build_temporal_model(config, model_cont_cols, temporal_cond_dim, device)
        opt_temporal = torch.optim.Adam(
            temporal_model.parameters(),
            lr=float(config.get("temporal_lr", config["lr"])),
        )
    ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None
    # Diffusion schedule (cosine betas -> cumulative alpha products).
    betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    os.makedirs(config["out_dir"], exist_ok=True)
    out_dir = Path(safe_path(config["out_dir"]))
    log_path = out_dir / "train_log.csv"
    init_or_append_log(log_path, args.resume)
    with open(out_dir / "config_used.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    # Checkpoint / artifact locations.
    main_ckpt_path = Path(args.resume_ckpt).resolve() if args.resume_ckpt else (out_dir / "model_ckpt.pt")
    temporal_ckpt_path = out_dir / "temporal_ckpt.pt"
    model_path = out_dir / "model.pt"
    ema_path = out_dir / "model_ema.pt"
    temporal_path = out_dir / "temporal.pt"
    # Resume bookkeeping: where each phase restarts from.
    temporal_start_epoch = 0
    temporal_start_step = 0
    temporal_total_step = 0
    main_start_epoch = 0
    main_start_step = 0
    total_step = 0
    temporal_done = temporal_model is None
    if args.resume:
        # Priority: full main checkpoint > in-progress temporal checkpoint >
        # completed temporal weights.
        if main_ckpt_path.exists():
            ckpt = load_torch_state(str(main_ckpt_path), device)
            model.load_state_dict(ckpt["model"])
            opt.load_state_dict(ckpt["optim"])
            total_step = int(ckpt.get("step", 0))
            main_start_epoch = int(ckpt.get("epoch", 0))
            main_start_step = int(ckpt.get("step_in_epoch", 0))
            temporal_done = bool(ckpt.get("temporal_done", temporal_done))
            temporal_total_step = int(ckpt.get("temporal_step", 0))
            if ema is not None and ckpt.get("ema") is not None:
                ema.shadow = ckpt["ema"]
            if temporal_model is not None and ckpt.get("temporal") is not None:
                temporal_model.load_state_dict(ckpt["temporal"])
            if opt_temporal is not None and ckpt.get("temporal_optim") is not None:
                opt_temporal.load_state_dict(ckpt["temporal_optim"])
            print(f"resumed_main_ckpt epoch={main_start_epoch} step={main_start_step} total_step={total_step}")
        elif temporal_ckpt_path.exists() and temporal_model is not None and opt_temporal is not None:
            tckpt = load_torch_state(str(temporal_ckpt_path), device)
            temporal_model.load_state_dict(tckpt["temporal"])
            opt_temporal.load_state_dict(tckpt["temporal_optim"])
            temporal_start_epoch = int(tckpt.get("epoch", 0))
            temporal_start_step = int(tckpt.get("step_in_epoch", 0))
            temporal_total_step = int(tckpt.get("temporal_step", 0))
            print(
                f"resumed_temporal_ckpt epoch={temporal_start_epoch} "
                f"step={temporal_start_step} temporal_step={temporal_total_step}"
            )
        elif temporal_path.exists() and temporal_model is not None:
            temporal_model.load_state_dict(load_torch_state(str(temporal_path), device))
            temporal_done = True
            print("reused_completed_temporal_stage", str(temporal_path))
    # ---- Stage 1: temporal trend pre-training (teacher forcing, next-step MSE).
    if temporal_model is not None and opt_temporal is not None and not temporal_done:
        for epoch in range(temporal_start_epoch, int(config.get("temporal_epochs", 1))):
            # When resuming mid-epoch, skip already-consumed batches.  NOTE(review):
            # this assumes the batch stream is deterministic per epoch — verify
            # against windowed_batches when shuffle_buffer > 0.
            skip_until = temporal_start_step if epoch == temporal_start_epoch else 0
            for step, batch in enumerate(
                windowed_batches(
                    data_paths,
                    cont_cols,
                    disc_cols,
                    vocab,
                    mean,
                    std,
                    batch_size=int(config["batch_size"]),
                    seq_len=int(config["seq_len"]),
                    max_batches=int(config["max_batches"]),
                    return_file_id=False,
                    transforms=transforms,
                    quantile_probs=quantile_probs,
                    quantile_values=quantile_values,
                    use_quantile=use_quantile,
                    shuffle_buffer=int(config.get("shuffle_buffer", 0)),
                )
            ):
                if step < skip_until:
                    continue
                x_cont, _ = batch
                x_cont = x_cont.to(device)
                # Select the diffusion-model columns (loop-invariant; could be hoisted).
                model_idx = [cont_cols.index(c) for c in model_cont_cols]
                x_cont_model = x_cont[:, :, model_idx]
                cond_cont = None
                if temporal_cond_dim > 0:
                    cond_idx = [cont_cols.index(c) for c in type1_cols]
                    cond_cont = x_cont[:, :, cond_idx]
                _, pred_next = temporal_model.forward_teacher(x_cont_model, cond_cont=cond_cont)
                target_next = x_cont_model[:, 1:, :]
                if trend_mask is not None:
                    # Masked MSE averaged over the selected feature columns only.
                    mask = trend_mask.to(dtype=pred_next.dtype, device=pred_next.device)
                    mse = (pred_next - target_next) ** 2
                    temporal_loss = (mse * mask).sum() / torch.clamp(mask.sum() * mse.size(0) * mse.size(1), min=1.0)
                else:
                    temporal_loss = F.mse_loss(pred_next, target_next)
                opt_temporal.zero_grad()
                temporal_loss.backward()
                if float(config.get("grad_clip", 0.0)) > 0:
                    torch.nn.utils.clip_grad_norm_(temporal_model.parameters(), float(config["grad_clip"]))
                opt_temporal.step()
                temporal_total_step += 1
                if step % int(config["log_every"]) == 0:
                    print("temporal_epoch", epoch, "step", step, "loss", float(temporal_loss))
                # Periodic resumable checkpoint (stores next step to run).
                if temporal_total_step % int(config["ckpt_every"]) == 0:
                    atomic_torch_save(
                        {
                            "temporal": temporal_model.state_dict(),
                            "temporal_optim": opt_temporal.state_dict(),
                            "epoch": epoch,
                            "step_in_epoch": step + 1,
                            "temporal_step": temporal_total_step,
                            "config": config,
                        },
                        temporal_ckpt_path,
                    )
            # Epoch finished: future resumes start the next epoch from step 0.
            temporal_start_step = 0
            atomic_torch_save(
                {
                    "temporal": temporal_model.state_dict(),
                    "temporal_optim": opt_temporal.state_dict(),
                    "epoch": epoch + 1,
                    "step_in_epoch": 0,
                    "temporal_step": temporal_total_step,
                    "config": config,
                },
                temporal_ckpt_path,
            )
        atomic_torch_save(temporal_model.state_dict(), temporal_path)
        temporal_done = True
    if bool(args.temporal_only):
        return
    # ---- Stage 2: main hybrid diffusion training on trend residuals.
    for epoch in range(main_start_epoch, int(config["epochs"])):
        skip_until = main_start_step if epoch == main_start_epoch else 0
        for step, batch in enumerate(
            windowed_batches(
                data_paths,
                cont_cols,
                disc_cols,
                vocab,
                mean,
                std,
                batch_size=int(config["batch_size"]),
                seq_len=int(config["seq_len"]),
                max_batches=int(config["max_batches"]),
                return_file_id=use_condition,
                transforms=transforms,
                quantile_probs=quantile_probs,
                quantile_values=quantile_values,
                use_quantile=use_quantile,
                shuffle_buffer=int(config.get("shuffle_buffer", 0)),
            )
        ):
            if step < skip_until:
                continue
            # Batch layout depends on whether file-id conditioning is on.
            if use_condition:
                x_cont, x_disc, cond = batch
                cond = cond.to(device)
            else:
                x_cont, x_disc = batch
                cond = None
            x_cont = x_cont.to(device)
            x_disc = x_disc.to(device)
            model_idx = [cont_cols.index(c) for c in model_cont_cols]
            cond_idx = [cont_cols.index(c) for c in type1_cols] if type1_cols else []
            x_cont_model = x_cont[:, :, model_idx]
            cond_cont = x_cont[:, :, cond_idx] if cond_idx else None
            # Subtract the (frozen) stage-1 trend so diffusion models residuals.
            trend = None
            if temporal_model is not None:
                temporal_model.eval()
                with torch.no_grad():
                    trend, _ = temporal_model.forward_teacher(x_cont_model, cond_cont=cond_cont)
                if trend_mask is not None and trend is not None:
                    trend = trend * trend_mask.to(dtype=trend.dtype, device=trend.device)
            x_cont_resid = x_cont_model if trend is None else x_cont_model - trend
            # Sample one diffusion timestep per sequence and noise the inputs.
            bsz = x_cont.size(0)
            t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
            x_cont_t, noise = q_sample_continuous(x_cont_resid, t, alphas_cumprod)
            # Discrete features use absorbing-mask diffusion: the mask token id
            # for column i is its vocab size.
            mask_tokens = torch.tensor(vocab_sizes, device=device)
            x_disc_t, mask = q_sample_discrete(
                x_disc,
                t,
                mask_tokens,
                int(config["timesteps"]),
                mask_scale=float(config.get("disc_mask_scale", 1.0)),
            )
            eps_pred, logits = model(x_cont_t, x_disc_t, t, cond, cond_cont=cond_cont)
            # Continuous loss: either predict x0 directly or predict the noise.
            cont_target = str(config.get("cont_target", "eps"))
            if cont_target == "x0":
                x0_target = x_cont_resid
                if float(config.get("cont_clamp_x0", 0.0)) > 0:
                    x0_target = torch.clamp(
                        x0_target,
                        -float(config["cont_clamp_x0"]),
                        float(config["cont_clamp_x0"]),
                    )
                loss_base = (eps_pred - x0_target) ** 2
            else:
                loss_base = (eps_pred - noise) ** 2
            if config.get("cont_loss_weighting") == "inv_std":
                # Per-feature inverse-variance weighting (raw, pre-normalization std).
                weights = torch.tensor(
                    [1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in model_cont_cols],
                    device=device,
                    dtype=eps_pred.dtype,
                ).view(1, 1, -1)
                loss_cont = (loss_base * weights).mean()
            else:
                loss_cont = loss_base.mean()
            if bool(config.get("snr_weighted_loss", False)):
                # Min-SNR-style weighting.  NOTE(review): loss_cont is already a
                # scalar here, so snr_weight.mean() collapses the per-sample
                # weights to one factor — confirm this is intended.
                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
                snr = a_bar_t / torch.clamp(1.0 - a_bar_t, min=1e-8)
                gamma = float(config.get("snr_gamma", 1.0))
                snr_weight = snr / (snr + gamma)
                loss_cont = (loss_cont * snr_weight.mean()).mean()
            # Discrete loss: cross-entropy on masked positions only, averaged
            # over columns that actually had masked tokens this batch.
            loss_disc = 0.0
            loss_disc_count = 0
            for i, logit in enumerate(logits):
                if mask[:, :, i].any():
                    loss_disc = loss_disc + F.cross_entropy(
                        logit[mask[:, :, i]],
                        x_disc[:, :, i][mask[:, :, i]],
                    )
                    loss_disc_count += 1
            if loss_disc_count > 0:
                loss_disc = loss_disc / loss_disc_count
            # Blend continuous and discrete losses.
            lam = float(config["lambda"])
            loss = lam * loss_cont + (1 - lam) * loss_disc
            # Optional marginal-quantile matching between real and generated x0.
            q_weight = float(config.get("quantile_loss_weight", 0.0))
            if q_weight > 0:
                q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
                q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
                x_real = x_cont_resid
                if cont_target == "x0":
                    x_gen = eps_pred
                else:
                    # Recover x0 estimate from the predicted noise.
                    x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
                x_real = x_real.view(-1, x_real.size(-1))
                x_gen = x_gen.view(-1, x_gen.size(-1))
                q_real = torch.quantile(x_real, q_tensor, dim=0)
                q_gen = torch.quantile(x_gen, q_tensor, dim=0)
                quantile_loss = torch.mean(torch.abs(q_gen - q_real))
                loss = loss + q_weight * quantile_loss
            # Optional first/second-moment matching on the residuals.
            stat_weight = float(config.get("residual_stat_weight", 0.0))
            if stat_weight > 0:
                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
                if cont_target == "x0":
                    x_gen = eps_pred
                else:
                    x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
                x_real = x_cont_resid
                mean_real = x_real.mean(dim=(0, 1))
                mean_gen = x_gen.mean(dim=(0, 1))
                std_real = x_real.std(dim=(0, 1))
                std_gen = x_gen.std(dim=(0, 1))
                stat_loss = F.mse_loss(mean_gen, mean_real) + F.mse_loss(std_gen, std_real)
                loss = loss + stat_weight * stat_loss
            opt.zero_grad()
            loss.backward()
            if float(config.get("grad_clip", 0.0)) > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
            opt.step()
            if ema is not None:
                ema.update(model)
            # Console + CSV logging at the configured cadence.
            if step % int(config["log_every"]) == 0:
                print("epoch", epoch, "step", step, "loss", float(loss))
                with open(log_path, "a", encoding="utf-8") as f:
                    f.write(
                        "%d,%d,%.6f,%.6f,%.6f\n"
                        % (epoch, step, float(loss), float(loss_cont), float(loss_disc))
                    )
            total_step += 1
            # Periodic resumable checkpoint of everything (model, optimizer,
            # EMA, temporal stage) so a crash loses at most ckpt_every steps.
            if total_step % int(config["ckpt_every"]) == 0:
                ckpt = {
                    "model": model.state_dict(),
                    "optim": opt.state_dict(),
                    "config": config,
                    "step": total_step,
                    "epoch": epoch,
                    "step_in_epoch": step + 1,
                    "temporal_done": temporal_done,
                    "temporal_step": temporal_total_step,
                }
                if ema is not None:
                    ckpt["ema"] = ema.state_dict()
                if temporal_model is not None:
                    ckpt["temporal"] = temporal_model.state_dict()
                if opt_temporal is not None:
                    ckpt["temporal_optim"] = opt_temporal.state_dict()
                atomic_torch_save(ckpt, main_ckpt_path)
        # Epoch boundary: future resumes start the next epoch from step 0.
        main_start_step = 0
        ckpt = {
            "model": model.state_dict(),
            "optim": opt.state_dict(),
            "config": config,
            "step": total_step,
            "epoch": epoch + 1,
            "step_in_epoch": 0,
            "temporal_done": temporal_done,
            "temporal_step": temporal_total_step,
        }
        if ema is not None:
            ckpt["ema"] = ema.state_dict()
        if temporal_model is not None:
            ckpt["temporal"] = temporal_model.state_dict()
        if opt_temporal is not None:
            ckpt["temporal_optim"] = opt_temporal.state_dict()
        atomic_torch_save(ckpt, main_ckpt_path)
    # Final artifacts: raw weights, EMA weights, temporal weights.
    atomic_torch_save(model.state_dict(), model_path)
    if ema is not None:
        atomic_torch_save(ema.state_dict(), ema_path)
    if temporal_model is not None:
        atomic_torch_save(temporal_model.state_dict(), temporal_path)
# Script entry point.
if __name__ == "__main__":
    main()