Add resumable submission pipeline
This commit is contained in:
81
example/config_submission_full.json
Normal file
81
example/config_submission_full.json
Normal file
@@ -0,0 +1,81 @@
|
||||
{
|
||||
"data_path": "../../dataset/hai/hai-21.03/train1.csv.gz",
|
||||
"data_glob": "../../dataset/hai/hai-21.03/train*.csv.gz",
|
||||
"split_path": "./feature_split.json",
|
||||
"stats_path": "./results/cont_stats.json",
|
||||
"vocab_path": "./results/disc_vocab.json",
|
||||
"out_dir": "./results",
|
||||
"device": "auto",
|
||||
"timesteps": 600,
|
||||
"batch_size": 12,
|
||||
"seq_len": 96,
|
||||
"epochs": 10,
|
||||
"max_batches": 4000,
|
||||
"lambda": 0.7,
|
||||
"lr": 0.0005,
|
||||
"seed": 1337,
|
||||
"log_every": 10,
|
||||
"ckpt_every": 50,
|
||||
"ema_decay": 0.999,
|
||||
"use_ema": true,
|
||||
"clip_k": 5.0,
|
||||
"grad_clip": 1.0,
|
||||
"use_condition": true,
|
||||
"condition_type": "file_id",
|
||||
"cond_dim": 32,
|
||||
"use_tanh_eps": false,
|
||||
"eps_scale": 1.0,
|
||||
"model_time_dim": 128,
|
||||
"model_hidden_dim": 512,
|
||||
"model_num_layers": 2,
|
||||
"model_dropout": 0.1,
|
||||
"model_ff_mult": 2,
|
||||
"model_pos_dim": 64,
|
||||
"model_use_pos_embed": true,
|
||||
"backbone_type": "transformer",
|
||||
"transformer_num_layers": 3,
|
||||
"transformer_nhead": 4,
|
||||
"transformer_ff_dim": 512,
|
||||
"transformer_dropout": 0.1,
|
||||
"disc_mask_scale": 0.9,
|
||||
"cont_loss_weighting": "inv_std",
|
||||
"cont_loss_eps": 1e-6,
|
||||
"cont_target": "x0",
|
||||
"cont_clamp_x0": 5.0,
|
||||
"use_quantile_transform": true,
|
||||
"quantile_bins": 1001,
|
||||
"cont_bound_mode": "none",
|
||||
"cont_bound_strength": 2.0,
|
||||
"cont_post_calibrate": true,
|
||||
"cont_post_scale": {},
|
||||
"full_stats": true,
|
||||
"type1_features": ["P1_B4002", "P2_MSD", "P4_HT_LD", "P1_B2004", "P1_B3004", "P1_B4022", "P1_B3005"],
|
||||
"type2_features": ["P1_B4005"],
|
||||
"type3_features": ["P1_PCV02Z", "P1_PCV01Z", "P1_PCV01D", "P1_FCV02Z", "P1_FCV03D", "P1_FCV03Z", "P1_LCV01D", "P1_LCV01Z"],
|
||||
"type4_features": ["P1_PIT02", "P2_SIT02", "P1_FT03"],
|
||||
"type5_features": ["P1_FT03Z", "P1_FT02Z"],
|
||||
"type6_features": ["P4_HT_PO", "P2_24Vdc", "P2_HILout"],
|
||||
"routing_type1_features": ["P1_B4022"],
|
||||
"routing_type5_features": [],
|
||||
"shuffle_buffer": 256,
|
||||
"use_temporal_stage1": true,
|
||||
"temporal_backbone": "transformer",
|
||||
"temporal_hidden_dim": 256,
|
||||
"temporal_num_layers": 1,
|
||||
"temporal_dropout": 0.0,
|
||||
"temporal_pos_dim": 64,
|
||||
"temporal_use_pos_embed": true,
|
||||
"temporal_transformer_num_layers": 2,
|
||||
"temporal_transformer_nhead": 4,
|
||||
"temporal_transformer_ff_dim": 256,
|
||||
"temporal_transformer_dropout": 0.1,
|
||||
"temporal_epochs": 3,
|
||||
"temporal_lr": 0.001,
|
||||
"quantile_loss_weight": 0.2,
|
||||
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
|
||||
"snr_weighted_loss": true,
|
||||
"snr_gamma": 1.0,
|
||||
"residual_stat_weight": 0.05,
|
||||
"sample_batch_size": 4,
|
||||
"sample_seq_len": 96
|
||||
}
|
||||
440
example/export_samples_resume.py
Normal file
440
example/export_samples_resume.py
Normal file
@@ -0,0 +1,440 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Sample from a trained hybrid diffusion model with routing-aware export fixes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import gzip
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from data_utils import inverse_quantile_transform, load_split, normalize_cont, quantile_calibrate_to_real
|
||||
from export_samples import build_inverse_vocab, load_stats, load_torch_state, read_header
|
||||
from hybrid_diffusion import (
|
||||
HybridDiffusionModel,
|
||||
TemporalGRUGenerator,
|
||||
TemporalTransformerGenerator,
|
||||
cosine_beta_schedule,
|
||||
)
|
||||
from platform_utils import resolve_device, resolve_path
|
||||
from submission_type_utils import (
|
||||
denormalize_cont_tensor,
|
||||
resolve_routing_features,
|
||||
resolve_taxonomy_features,
|
||||
)
|
||||
|
||||
|
||||
def parse_args():
    """Build and parse the CLI arguments for routing-aware sample export."""
    here = Path(__file__).resolve().parent
    repo = here.parent.parent
    dataset = repo / "dataset" / "hai" / "hai-21.03"
    results = here / "results"

    ap = argparse.ArgumentParser(description="Sample and export HAI feature sequences with routing-aware fixes.")
    # Input artifacts produced by the prepare/train stages.
    ap.add_argument("--data-path", default=str(dataset / "train1.csv.gz"))
    ap.add_argument("--data-glob", default=str(dataset / "train*.csv.gz"))
    ap.add_argument("--split-path", default=str(here / "feature_split.json"))
    ap.add_argument("--stats-path", default=str(results / "cont_stats.json"))
    ap.add_argument("--vocab-path", default=str(results / "disc_vocab.json"))
    ap.add_argument("--model-path", default=str(results / "model.pt"))
    ap.add_argument("--out", default=str(results / "generated.csv"))
    # Sampling controls.
    ap.add_argument("--timesteps", type=int, default=200)
    ap.add_argument("--seq-len", type=int, default=64)
    ap.add_argument("--batch-size", type=int, default=2)
    ap.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    ap.add_argument("--include-time", action="store_true", help="Include time column as a simple index")
    ap.add_argument("--clip-k", type=float, default=5.0, help="Clip continuous values to mean±k*std")
    ap.add_argument("--use-ema", action="store_true", help="Use EMA weights if available")
    # Optional conditioning, inferred from a training config.
    ap.add_argument("--config", default=None, help="Optional config_used.json to infer conditioning")
    ap.add_argument("--condition-id", type=int, default=-1, help="Condition file id (0..N-1), -1=random")
    ap.add_argument("--include-condition", action="store_true", help="Include condition id column in CSV")
    return ap.parse_args()
|
||||
|
||||
|
||||
def main():
    """Sample sequences from a trained hybrid diffusion model and export them as CSV.

    Pipeline: resolve paths -> load split/stats/vocab -> rebuild the model(s)
    from the training config -> run reverse diffusion -> denormalize -> write
    one CSV row per (batch, timestep) position.
    """
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    # Make every user-supplied path absolute relative to this script's directory.
    args.data_path = str(resolve_path(base_dir, args.data_path))
    args.data_glob = str(resolve_path(base_dir, args.data_glob)) if args.data_glob else ""
    args.split_path = str(resolve_path(base_dir, args.split_path))
    args.stats_path = str(resolve_path(base_dir, args.stats_path))
    args.vocab_path = str(resolve_path(base_dir, args.vocab_path))
    args.model_path = str(resolve_path(base_dir, args.model_path))
    args.out = str(resolve_path(base_dir, args.out))

    if not os.path.exists(args.model_path):
        raise SystemExit(f"missing model file: {args.model_path}")

    # Prefer the first file matching --data-glob (used only for its CSV header).
    data_path = args.data_path
    if args.data_glob:
        base = Path(args.data_glob).parent
        pat = Path(args.data_glob).name
        matches = sorted(base.glob(pat))
        if matches:
            data_path = str(matches[0])

    # Feature split: continuous vs discrete columns; attack*/time are excluded
    # from the discrete set so they are never sampled.
    split = load_split(args.split_path)
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    # Per-column normalization statistics produced by the prepare stage.
    stats = load_stats(args.stats_path)
    mean = stats["mean"]
    std = stats["std"]
    vmin = stats.get("min", {})
    vmax = stats.get("max", {})
    int_like = stats.get("int_like", {})
    max_decimals = stats.get("max_decimals", {})
    transforms = stats.get("transform", {})
    quantile_probs = stats.get("quantile_probs")
    quantile_values = stats.get("quantile_values")
    quantile_raw_values = stats.get("quantile_raw_values")

    # Discrete token vocabularies; index len(vocab[c]) is the [MASK] token.
    vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
    vocab = vocab_json["vocab"]
    top_token = vocab_json.get("top_token", {})
    inv_vocab = build_inverse_vocab(vocab)
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    device = resolve_device(args.device)
    cfg = {}
    use_condition = False
    cond_vocab_size = 0
    if args.config:
        args.config = str(resolve_path(base_dir, args.config))
    if args.config and os.path.exists(args.config):
        with open(args.config, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        # File-id conditioning: the condition vocabulary size equals the number
        # of training files that matched data_glob at train time.
        use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
        if use_condition:
            cfg_base = Path(args.config).resolve().parent
            cfg_glob = cfg.get("data_glob", args.data_glob)
            cfg_glob = str(resolve_path(cfg_base, cfg_glob))
            base = Path(cfg_glob).parent
            pat = Path(cfg_glob).name
            cond_vocab_size = len(sorted(base.glob(pat)))
            if cond_vocab_size <= 0:
                raise SystemExit(f"use_condition enabled but no files matched data_glob: {cfg_glob}")

    # Continuous-head behavior knobs (defaults match training defaults).
    cont_target = str(cfg.get("cont_target", "eps"))
    cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
    use_quantile = bool(cfg.get("use_quantile_transform", False))
    cont_bound_mode = str(cfg.get("cont_bound_mode", "clamp"))
    cont_bound_strength = float(cfg.get("cont_bound_strength", 1.0))
    cont_post_scale = cfg.get("cont_post_scale", {}) if isinstance(cfg.get("cont_post_scale", {}), dict) else {}
    cont_post_calibrate = bool(cfg.get("cont_post_calibrate", False))

    # Routed columns are NOT generated by the diffusion model: type1 columns
    # are fed in as conditioning, type5 columns are copied from their base
    # column at export time.
    route_type1_cols = resolve_routing_features(cfg, cont_cols, "type1_features")
    route_type5_cols = resolve_routing_features(cfg, cont_cols, "type5_features")
    type4_cols = resolve_taxonomy_features(cfg, cont_cols, "type4_features")
    model_cont_cols = [c for c in cont_cols if c not in route_type1_cols and c not in route_type5_cols]

    # Optional stage-1 temporal trend generator hyperparameters.
    use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
    temporal_use_type1_cond = bool(cfg.get("temporal_use_type1_cond", False))
    temporal_focus_type4 = bool(cfg.get("temporal_focus_type4", False))
    temporal_exclude_type4 = bool(cfg.get("temporal_exclude_type4", False))
    temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
    temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
    temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
    temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
    temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64))
    temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True))
    temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2))
    temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4))
    temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512))
    temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1))
    backbone_type = str(cfg.get("backbone_type", "gru"))
    transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
    transformer_nhead = int(cfg.get("transformer_nhead", 4))
    transformer_ff_dim = int(cfg.get("transformer_ff_dim", 512))
    transformer_dropout = float(cfg.get("transformer_dropout", 0.1))

    # Rebuild the diffusion model with the exact architecture used at training
    # time (all dims must match the checkpoint or load_state_dict will fail).
    model = HybridDiffusionModel(
        cont_dim=len(model_cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=int(cfg.get("model_time_dim", 64)),
        hidden_dim=int(cfg.get("model_hidden_dim", 256)),
        num_layers=int(cfg.get("model_num_layers", 1)),
        dropout=float(cfg.get("model_dropout", 0.0)),
        ff_mult=int(cfg.get("model_ff_mult", 2)),
        pos_dim=int(cfg.get("model_pos_dim", 64)),
        use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
        backbone_type=backbone_type,
        transformer_num_layers=transformer_num_layers,
        transformer_nhead=transformer_nhead,
        transformer_ff_dim=transformer_ff_dim,
        transformer_dropout=transformer_dropout,
        cond_cont_dim=len(route_type1_cols),
        cond_vocab_size=cond_vocab_size if use_condition else 0,
        cond_dim=int(cfg.get("cond_dim", 32)),
        use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
        eps_scale=float(cfg.get("eps_scale", 1.0)),
    ).to(device)
    # Prefer EMA weights when requested and present (model_ema.pt next to model.pt).
    if args.use_ema and os.path.exists(args.model_path.replace("model.pt", "model_ema.pt")):
        ema_path = args.model_path.replace("model.pt", "model_ema.pt")
        model.load_state_dict(load_torch_state(ema_path, device))
    else:
        model.load_state_dict(load_torch_state(args.model_path, device))
    model.eval()

    # Optional stage-1 temporal trend model (temporal.pt next to model.pt).
    temporal_model = None
    if use_temporal_stage1:
        temporal_path = Path(args.model_path).with_name("temporal.pt")
        if not temporal_path.exists():
            raise SystemExit(f"missing temporal model file: {temporal_path}")
        temporal_state = load_torch_state(str(temporal_path), device)
        temporal_cond_dim = len(route_type1_cols) if (temporal_use_type1_cond and route_type1_cols) else 0
        # Infer the checkpoint's actual conditioning width from its first-layer
        # weight shape, so a config/checkpoint mismatch does not break loading.
        if isinstance(temporal_state, dict):
            if "in_proj.weight" in temporal_state:
                try:
                    temporal_cond_dim = max(0, int(temporal_state["in_proj.weight"].shape[1]) - len(model_cont_cols))
                except Exception:
                    pass
            elif "gru.weight_ih_l0" in temporal_state:
                try:
                    temporal_cond_dim = max(0, int(temporal_state["gru.weight_ih_l0"].shape[1]) - len(model_cont_cols))
                except Exception:
                    pass
        if temporal_backbone == "transformer":
            temporal_model = TemporalTransformerGenerator(
                input_dim=len(model_cont_cols),
                hidden_dim=temporal_hidden_dim,
                num_layers=temporal_transformer_num_layers,
                nhead=temporal_transformer_nhead,
                ff_dim=temporal_transformer_ff_dim,
                dropout=temporal_transformer_dropout,
                pos_dim=temporal_pos_dim,
                use_pos_embed=temporal_use_pos_embed,
                cond_dim=temporal_cond_dim,
            ).to(device)
        else:
            temporal_model = TemporalGRUGenerator(
                input_dim=len(model_cont_cols),
                hidden_dim=temporal_hidden_dim,
                num_layers=temporal_num_layers,
                dropout=temporal_dropout,
                cond_dim=temporal_cond_dim,
            ).to(device)
        temporal_model.load_state_dict(temporal_state)
        temporal_model.eval()

    # DDPM noise schedule.
    betas = cosine_beta_schedule(args.timesteps).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # Initial state: continuous channels start from pure Gaussian noise
    # (batch, seq_len, n_cont); discrete channels start fully masked.
    x_cont = torch.randn(args.batch_size, args.seq_len, len(model_cont_cols), device=device)
    x_disc = torch.full((args.batch_size, args.seq_len, len(disc_cols)), 0, device=device, dtype=torch.long)
    mask_tokens = torch.tensor(vocab_sizes, device=device)
    for i in range(len(disc_cols)):
        x_disc[:, :, i] = mask_tokens[i]

    # File-id conditioning vector: fixed id or one random id per batch element.
    cond = None
    if use_condition:
        if cond_vocab_size <= 0:
            raise SystemExit("use_condition enabled but no files matched data_glob")
        if args.condition_id < 0:
            cond_id = torch.randint(0, cond_vocab_size, (args.batch_size,), device=device)
        else:
            cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
        cond = cond_id

    # Type1 conditioning: read the first seq_len rows of a real reference file,
    # normalize them, and broadcast the same window to every batch element.
    cond_cont = None
    if route_type1_cols:
        ref_glob = cfg.get("data_glob") or args.data_glob
        if ref_glob:
            ref_glob = str(resolve_path(Path(args.config).parent, ref_glob)) if args.config else ref_glob
            base = Path(ref_glob).parent
            pat = Path(ref_glob).name
            refs = sorted(base.glob(pat))
            if refs:
                ref_path = refs[0]
                ref_rows = []
                with gzip.open(ref_path, "rt", newline="") as fh:
                    reader = csv.DictReader(fh)
                    for row in reader:
                        ref_rows.append(row)
                if len(ref_rows) >= args.seq_len:
                    seq = ref_rows[: args.seq_len]
                    cond_cont = torch.zeros(args.batch_size, args.seq_len, len(route_type1_cols), device=device)
                    for t, row in enumerate(seq):
                        for i, c in enumerate(route_type1_cols):
                            cond_cont[:, t, i] = float(row[c])
                    cond_cont = normalize_cont(
                        cond_cont,
                        route_type1_cols,
                        mean,
                        std,
                        transforms=transforms,
                        quantile_probs=quantile_probs,
                        quantile_values=quantile_values,
                        use_quantile=use_quantile,
                    )

    # Stage-1 trend, optionally restricted to (or excluding) type4 channels.
    trend = None
    if temporal_model is not None:
        trend = temporal_model.generate(args.batch_size, args.seq_len, device, cond_cont=cond_cont)
        if temporal_focus_type4 and type4_cols:
            type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols]
            if type4_model_idx:
                trend_mask = torch.zeros(1, 1, len(model_cont_cols), device=device, dtype=trend.dtype)
                trend_mask[:, :, type4_model_idx] = 1.0
                trend = trend * trend_mask
        elif temporal_exclude_type4 and type4_cols:
            type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols]
            if type4_model_idx:
                trend_mask = torch.ones(1, 1, len(model_cont_cols), device=device, dtype=trend.dtype)
                trend_mask[:, :, type4_model_idx] = 0.0
                trend = trend * trend_mask

    # Reverse diffusion loop, t = T-1 .. 0.
    for t in reversed(range(args.timesteps)):
        t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
        eps_pred, logits = model(x_cont, x_disc, t_batch, cond, cond_cont=cond_cont)

        a_t = alphas[t]
        a_bar_t = alphas_cumprod[t]

        # If the model was trained to predict x0, convert its output back to
        # an epsilon estimate before applying the standard DDPM update.
        if cont_target == "x0":
            x0_pred = eps_pred
            if cont_clamp_x0 > 0:
                x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
            eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
        coef1 = 1.0 / torch.sqrt(a_t)
        coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
        mean_x = coef1 * (x_cont - coef2 * eps_pred)
        if t > 0:
            noise = torch.randn_like(x_cont)
            x_cont = mean_x + torch.sqrt(betas[t]) * noise
        else:
            x_cont = mean_x
        if args.clip_k > 0:
            x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)

        # Discrete channels: greedy argmax at the final step, otherwise sample
        # replacements for positions still holding the [MASK] token.
        for i, logit in enumerate(logits):
            if t == 0:
                probs = F.softmax(logit, dim=-1)
                x_disc[:, :, i] = torch.argmax(probs, dim=-1)
            else:
                mask = x_disc[:, :, i] == mask_tokens[i]
                if mask.any():
                    probs = F.softmax(logit, dim=-1)
                    sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(
                        args.batch_size, args.seq_len
                    )
                    x_disc[:, :, i][mask] = sampled[mask]

    # The stage-1 trend is added back on top of the diffusion residual.
    if trend is not None:
        x_cont = x_cont + trend
    x_cont = x_cont.cpu()
    x_disc = x_disc.cpu()

    if args.clip_k > 0:
        x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)

    # Denormalize: inverse quantile transform, or mean/std (+ inverse log1p).
    if use_quantile:
        q_vals = {c: quantile_values[c] for c in model_cont_cols}
        x_cont = inverse_quantile_transform(x_cont, model_cont_cols, quantile_probs, q_vals)
    else:
        mean_vec = torch.tensor([mean[c] for c in model_cont_cols], dtype=x_cont.dtype)
        std_vec = torch.tensor([std[c] for c in model_cont_cols], dtype=x_cont.dtype)
        x_cont = x_cont * std_vec + mean_vec
        for i, c in enumerate(model_cont_cols):
            if transforms.get(c) == "log1p":
                x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
    # Optional quantile calibration against the raw (untransformed) data.
    if cont_post_calibrate and quantile_raw_values and quantile_probs:
        q_raw = {c: quantile_raw_values[c] for c in model_cont_cols}
        x_cont = quantile_calibrate_to_real(x_cont, model_cont_cols, quantile_probs, q_raw)
    # Bound each column to its observed [min, max] using the configured mode.
    if vmin and vmax:
        for i, c in enumerate(model_cont_cols):
            lo = vmin.get(c, None)
            hi = vmax.get(c, None)
            if lo is None or hi is None:
                continue
            lo = float(lo)
            hi = float(hi)
            if cont_bound_mode == "none":
                continue
            if cont_bound_mode == "sigmoid":
                x_cont[:, :, i] = lo + (hi - lo) * torch.sigmoid(x_cont[:, :, i])
            elif cont_bound_mode == "soft_tanh":
                mid = 0.5 * (lo + hi)
                half = 0.5 * (hi - lo)
                denom = cont_bound_strength if cont_bound_strength > 0 else 1.0
                x_cont[:, :, i] = mid + half * torch.tanh(x_cont[:, :, i] / denom)
            else:
                # Default / "clamp" mode.
                x_cont[:, :, i] = torch.clamp(x_cont[:, :, i], lo, hi)

    # Per-column post-scaling from config; bad values silently fall back to 1.0.
    if cont_post_scale:
        for i, c in enumerate(model_cont_cols):
            if c in cont_post_scale:
                try:
                    scale = float(cont_post_scale[c])
                except Exception:
                    scale = 1.0
                x_cont[:, :, i] = x_cont[:, :, i] * scale

    # Assemble the full continuous tensor in original cont_cols order:
    # model outputs + denormalized type1 conditioning + type5 "*Z" copies.
    full_cont = torch.zeros(args.batch_size, args.seq_len, len(cont_cols), dtype=x_cont.dtype)
    for i, c in enumerate(model_cont_cols):
        full_idx = cont_cols.index(c)
        full_cont[:, :, full_idx] = x_cont[:, :, i]
    if cond_cont is not None and route_type1_cols:
        cond_denorm = denormalize_cont_tensor(
            cond_cont.cpu(),
            route_type1_cols,
            mean,
            std,
            transforms=transforms,
            quantile_probs=quantile_probs,
            quantile_values=quantile_values,
            use_quantile=use_quantile,
        )
        for i, c in enumerate(route_type1_cols):
            full_idx = cont_cols.index(c)
            full_cont[:, :, full_idx] = cond_denorm[:, :, i]
    # Type5 columns named "<base>Z" mirror their base column exactly.
    for c in route_type5_cols:
        if c.endswith("Z"):
            base = c[:-1]
            if base in cont_cols:
                bidx = cont_cols.index(base)
                cidx = cont_cols.index(c)
                full_cont[:, :, cidx] = full_cont[:, :, bidx]

    # Output columns follow the real data header; columns not generated here
    # (e.g. attack*) are left to DictWriter's empty restval.
    header = read_header(data_path)
    out_cols = [c for c in header if c != time_col or args.include_time]
    if args.include_condition and use_condition:
        out_cols = ["__cond_file_id"] + out_cols

    os.makedirs(os.path.dirname(args.out), exist_ok=True)
    with open(args.out, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=out_cols)
        writer.writeheader()

        # One row per (batch element, timestep); batches are concatenated.
        row_index = 0
        for b in range(args.batch_size):
            for t in range(args.seq_len):
                row = {}
                if args.include_condition and use_condition:
                    row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1"
                if args.include_time and time_col in header:
                    row[time_col] = str(row_index)
                for i, c in enumerate(cont_cols):
                    val = float(full_cont[b, t, i])
                    # Integer-like columns are rounded; others use the number
                    # of decimals observed in the real data.
                    if int_like.get(c, False):
                        row[c] = str(int(round(val)))
                    else:
                        dec = int(max_decimals.get(c, 6))
                        fmt = ("%%.%df" % dec) if dec > 0 else "%.0f"
                        row[c] = fmt % val
                for i, c in enumerate(disc_cols):
                    tok_idx = int(x_disc[b, t, i])
                    # Out-of-vocab indices fall back to <UNK>, then to the most
                    # frequent real token for that column when available.
                    tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "<UNK>"
                    if tok == "<UNK>" and c in top_token:
                        tok = top_token[c]
                    row[c] = tok
                writer.writerow(row)
                row_index += 1
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
19
example/run_submission_full.sh
Normal file
19
example/run_submission_full.sh
Normal file
@@ -0,0 +1,19 @@
|
||||
#!/usr/bin/env bash
# Launch the full resumable submission pipeline, mirroring all output into a
# timestamped log file under the run directory.
set -euo pipefail

# Directory containing this script; sibling tools are resolved relative to it.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# Run artifacts land here; override with RUN_DIR=... in the environment.
RUN_DIR="${RUN_DIR:-$SCRIPT_DIR/results/submission_full}"
LOG_DIR="$RUN_DIR/logs"
mkdir -p "$LOG_DIR"

# One log file per invocation, named by start time.
STAMP="$(date '+%Y%m%d-%H%M%S')"
LOG_FILE="$LOG_DIR/pipeline-$STAMP.log"

echo "[run_submission_full] run_dir=$RUN_DIR"
echo "[run_submission_full] log_file=$LOG_FILE"

# DEVICE=... overrides the default cuda device; extra CLI flags pass straight
# through to the Python runner via "$@". Both stdout and stderr are teed.
python "$SCRIPT_DIR/run_submission_resume.py" \
    --config "$SCRIPT_DIR/config_submission_full.json" \
    --device "${DEVICE:-cuda}" \
    --run-dir "$RUN_DIR" \
    "$@" 2>&1 | tee -a "$LOG_FILE"
|
||||
399
example/run_submission_resume.py
Normal file
399
example/run_submission_resume.py
Normal file
@@ -0,0 +1,399 @@
|
||||
#!/usr/bin/env python3
|
||||
"""One-command full pipeline runner with safe resume and stage skipping."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from platform_utils import is_windows, safe_path
|
||||
|
||||
|
||||
def run(cmd: List[str]) -> None:
    """Run one pipeline stage as a subprocess, raising on non-zero exit.

    Fix: the command is mapped through ``safe_path`` BEFORE being printed, so
    the logged line matches exactly what gets executed (the original printed
    the pre-mapped command, which could differ from the real invocation).
    The redundant Windows branch is also collapsed: both branches called
    ``subprocess.run`` with ``shell=False``, which is now stated explicitly.

    Raises:
        subprocess.CalledProcessError: if the child exits with a non-zero code.
    """
    cmd = [safe_path(arg) for arg in cmd]
    print("running:", " ".join(cmd))
    # Always an argv list with shell=False: no shell parsing on any platform.
    subprocess.run(cmd, check=True, shell=False)
|
||||
|
||||
|
||||
def parse_args():
    """Build and parse the CLI arguments for the staged pipeline runner."""
    here = Path(__file__).resolve().parent
    ap = argparse.ArgumentParser(description="Run prepare -> train -> export -> eval with resume-aware staging.")
    ap.add_argument("--config", default=str(here / "config_submission_full.json"))
    ap.add_argument("--device", default="auto")
    ap.add_argument("--run-dir", default=str(here / "results" / "submission_full"))
    ap.add_argument("--reference", default="")
    ap.add_argument("--no-resume", action="store_true", help="Do not auto-skip completed stages or resume from ckpt.")
    # One boolean skip flag per pipeline stage.
    for stage in (
        "prepare",
        "train",
        "export",
        "eval",
        "comprehensive-eval",
        "postprocess",
        "post-eval",
        "diagnostics",
    ):
        ap.add_argument(f"--skip-{stage}", action="store_true")
    return ap.parse_args()
|
||||
|
||||
|
||||
def load_state(path: Path) -> Dict[str, str]:
    """Load the persisted pipeline stage state, returning {} on any problem.

    A missing, unreadable, or corrupt state file yields an empty state so the
    pipeline simply re-runs stages instead of crashing.

    Fixes vs. original: the bare ``except Exception`` is narrowed to the
    exceptions the read/parse can actually raise, and valid JSON that is not
    an object (e.g. a list) no longer leaks through the declared
    ``Dict[str, str]`` return type.
    """
    if not path.exists():
        return {}
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: unreadable file; ValueError covers json.JSONDecodeError.
        return {}
    return data if isinstance(data, dict) else {}
|
||||
|
||||
|
||||
def save_state(path: Path, state: Dict[str, str]) -> None:
    """Persist the pipeline stage state as pretty-printed, key-sorted JSON."""
    serialized = json.dumps(state, indent=2, sort_keys=True)
    # Create the run directory on first save.
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(serialized, encoding="utf-8")
|
||||
|
||||
|
||||
def stage_complete(state: Dict[str, str], stage: str, outputs: List[Path], resume: bool) -> bool:
    """Return True when a stage can safely be skipped under resume mode.

    A stage counts as complete only when resume is enabled and either every
    expected output file already exists on disk, or the recorded state marks
    the stage as "done".
    """
    if not resume:
        return False
    have_all_outputs = bool(outputs) and all(p.exists() for p in outputs)
    return have_all_outputs or state.get(stage) == "done"
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
base_dir = Path(__file__).resolve().parent
|
||||
config_path = Path(args.config)
|
||||
if not config_path.is_absolute():
|
||||
config_path = (base_dir / config_path).resolve()
|
||||
run_dir = Path(args.run_dir)
|
||||
if not run_dir.is_absolute():
|
||||
run_dir = (base_dir / run_dir).resolve()
|
||||
run_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
cfg = json.loads(config_path.read_text(encoding="utf-8"))
|
||||
cfg_base = config_path.parent
|
||||
|
||||
def abs_cfg_like(value: str) -> str:
|
||||
p = Path(value)
|
||||
if p.is_absolute():
|
||||
return str(p)
|
||||
if any(ch in value for ch in ["*", "?", "["]):
|
||||
return str(cfg_base / p)
|
||||
return str((cfg_base / p).resolve())
|
||||
|
||||
ref = args.reference or cfg.get("data_glob") or cfg.get("data_path") or ""
|
||||
if ref:
|
||||
ref = abs_cfg_like(str(ref))
|
||||
|
||||
timesteps = int(cfg.get("timesteps", 200))
|
||||
seq_len = int(cfg.get("sample_seq_len", cfg.get("seq_len", 64)))
|
||||
batch_size = int(cfg.get("sample_batch_size", cfg.get("batch_size", 2)))
|
||||
clip_k = float(cfg.get("clip_k", 5.0))
|
||||
split_path = abs_cfg_like(str(cfg.get("split_path", "./feature_split.json")))
|
||||
stats_path = abs_cfg_like(str(cfg.get("stats_path", "./results/cont_stats.json")))
|
||||
vocab_path = abs_cfg_like(str(cfg.get("vocab_path", "./results/disc_vocab.json")))
|
||||
data_path = abs_cfg_like(str(cfg.get("data_path", ""))) if cfg.get("data_path") else ""
|
||||
data_glob = abs_cfg_like(str(cfg.get("data_glob", ""))) if cfg.get("data_glob") else ""
|
||||
|
||||
state_path = run_dir / "pipeline_state.json"
|
||||
state = load_state(state_path)
|
||||
resume = not args.no_resume
|
||||
cfg_for_steps = run_dir / "config_used.json"
|
||||
|
||||
stage_defs = []
|
||||
if not args.skip_prepare:
|
||||
stage_defs.append(
|
||||
(
|
||||
"prepare",
|
||||
[Path(stats_path), Path(vocab_path)],
|
||||
[sys.executable, str(base_dir / "prepare_data.py"), "--config", str(config_path)],
|
||||
)
|
||||
)
|
||||
if not args.skip_train:
|
||||
train_cmd = [
|
||||
sys.executable,
|
||||
str(base_dir / "train_resume.py"),
|
||||
"--config",
|
||||
str(config_path),
|
||||
"--device",
|
||||
args.device,
|
||||
"--out-dir",
|
||||
str(run_dir),
|
||||
"--seed",
|
||||
str(int(cfg.get("seed", 1337))),
|
||||
]
|
||||
if resume:
|
||||
train_cmd.append("--resume")
|
||||
stage_defs.append(("train", [run_dir / "model.pt"], train_cmd))
|
||||
if not args.skip_export:
|
||||
stage_defs.append(
|
||||
(
|
||||
"export",
|
||||
[run_dir / "generated.csv"],
|
||||
[
|
||||
sys.executable,
|
||||
str(base_dir / "export_samples_resume.py"),
|
||||
"--include-time",
|
||||
"--device",
|
||||
args.device,
|
||||
"--config",
|
||||
str(cfg_for_steps if cfg_for_steps.exists() else config_path),
|
||||
"--data-path",
|
||||
str(data_path),
|
||||
"--data-glob",
|
||||
str(data_glob),
|
||||
"--split-path",
|
||||
str(split_path),
|
||||
"--stats-path",
|
||||
str(stats_path),
|
||||
"--vocab-path",
|
||||
str(vocab_path),
|
||||
"--model-path",
|
||||
str(run_dir / "model.pt"),
|
||||
"--out",
|
||||
str(run_dir / "generated.csv"),
|
||||
"--timesteps",
|
||||
str(timesteps),
|
||||
"--seq-len",
|
||||
str(seq_len),
|
||||
"--batch-size",
|
||||
str(batch_size),
|
||||
"--clip-k",
|
||||
str(clip_k),
|
||||
"--use-ema",
|
||||
],
|
||||
)
|
||||
)
|
||||
if not args.skip_eval:
|
||||
eval_cmd = [
|
||||
sys.executable,
|
||||
str(base_dir / "evaluate_generated.py"),
|
||||
"--generated",
|
||||
str(run_dir / "generated.csv"),
|
||||
"--split",
|
||||
str(split_path),
|
||||
"--stats",
|
||||
str(stats_path),
|
||||
"--vocab",
|
||||
str(vocab_path),
|
||||
"--out",
|
||||
str(run_dir / "eval.json"),
|
||||
]
|
||||
if ref:
|
||||
eval_cmd += ["--reference", str(ref)]
|
||||
stage_defs.append(("eval", [run_dir / "eval.json"], eval_cmd))
|
||||
if not args.skip_comprehensive_eval:
|
||||
stage_defs.append(
|
||||
(
|
||||
"comprehensive_eval",
|
||||
[run_dir / "comprehensive_eval.json"],
|
||||
[
|
||||
sys.executable,
|
||||
str(base_dir / "evaluate_comprehensive.py"),
|
||||
"--generated",
|
||||
str(run_dir / "generated.csv"),
|
||||
"--reference",
|
||||
str(config_path),
|
||||
"--config",
|
||||
str(cfg_for_steps if cfg_for_steps.exists() else config_path),
|
||||
"--split",
|
||||
str(split_path),
|
||||
"--stats",
|
||||
str(stats_path),
|
||||
"--vocab",
|
||||
str(vocab_path),
|
||||
"--out",
|
||||
str(run_dir / "comprehensive_eval.json"),
|
||||
"--device",
|
||||
args.device,
|
||||
],
|
||||
)
|
||||
)
|
||||
if not args.skip_postprocess:
|
||||
post_cmd = [
|
||||
sys.executable,
|
||||
str(base_dir / "postprocess_types.py"),
|
||||
"--generated",
|
||||
str(run_dir / "generated.csv"),
|
||||
"--config",
|
||||
str(cfg_for_steps if cfg_for_steps.exists() else config_path),
|
||||
"--out",
|
||||
str(run_dir / "generated_post.csv"),
|
||||
"--seed",
|
||||
str(int(cfg.get("seed", 1337))),
|
||||
]
|
||||
if ref:
|
||||
post_cmd += ["--reference", str(ref)]
|
||||
stage_defs.append(("postprocess", [run_dir / "generated_post.csv"], post_cmd))
|
||||
if not args.skip_post_eval:
|
||||
post_eval_cmd = [
|
||||
sys.executable,
|
||||
str(base_dir / "evaluate_generated.py"),
|
||||
"--generated",
|
||||
str(run_dir / "generated_post.csv"),
|
||||
"--split",
|
||||
str(split_path),
|
||||
"--stats",
|
||||
str(stats_path),
|
||||
"--vocab",
|
||||
str(vocab_path),
|
||||
"--out",
|
||||
str(run_dir / "eval_post.json"),
|
||||
]
|
||||
if ref:
|
||||
post_eval_cmd += ["--reference", str(ref)]
|
||||
stage_defs.append(("post_eval", [run_dir / "eval_post.json"], post_eval_cmd))
|
||||
if not args.skip_comprehensive_eval:
|
||||
stage_defs.append(
|
||||
(
|
||||
"comprehensive_post_eval",
|
||||
[run_dir / "comprehensive_eval_post.json"],
|
||||
[
|
||||
sys.executable,
|
||||
str(base_dir / "evaluate_comprehensive.py"),
|
||||
"--generated",
|
||||
str(run_dir / "generated_post.csv"),
|
||||
"--reference",
|
||||
str(config_path),
|
||||
"--config",
|
||||
str(cfg_for_steps if cfg_for_steps.exists() else config_path),
|
||||
"--split",
|
||||
str(split_path),
|
||||
"--stats",
|
||||
str(stats_path),
|
||||
"--vocab",
|
||||
str(vocab_path),
|
||||
"--out",
|
||||
str(run_dir / "comprehensive_eval_post.json"),
|
||||
"--device",
|
||||
args.device,
|
||||
],
|
||||
)
|
||||
)
|
||||
if not args.skip_diagnostics:
|
||||
stage_defs.extend(
|
||||
[
|
||||
(
|
||||
"filtered_metrics",
|
||||
[run_dir / "filtered_metrics.json"],
|
||||
[
|
||||
sys.executable,
|
||||
str(base_dir / "filtered_metrics.py"),
|
||||
"--eval",
|
||||
str(run_dir / "eval.json"),
|
||||
"--out",
|
||||
str(run_dir / "filtered_metrics.json"),
|
||||
],
|
||||
),
|
||||
(
|
||||
"ranked_ks",
|
||||
[run_dir / "ranked_ks.csv"],
|
||||
[
|
||||
sys.executable,
|
||||
str(base_dir / "ranked_ks.py"),
|
||||
"--eval",
|
||||
str(run_dir / "eval.json"),
|
||||
"--out",
|
||||
str(run_dir / "ranked_ks.csv"),
|
||||
],
|
||||
),
|
||||
(
|
||||
"program_stats",
|
||||
[run_dir / "program_stats.json"],
|
||||
[
|
||||
sys.executable,
|
||||
str(base_dir / "program_stats.py"),
|
||||
"--generated",
|
||||
str(run_dir / "generated.csv"),
|
||||
"--reference",
|
||||
str(config_path),
|
||||
"--config",
|
||||
str(cfg_for_steps if cfg_for_steps.exists() else config_path),
|
||||
],
|
||||
),
|
||||
(
|
||||
"controller_stats",
|
||||
[run_dir / "controller_stats.json"],
|
||||
[
|
||||
sys.executable,
|
||||
str(base_dir / "controller_stats.py"),
|
||||
"--generated",
|
||||
str(run_dir / "generated.csv"),
|
||||
"--reference",
|
||||
str(config_path),
|
||||
"--config",
|
||||
str(cfg_for_steps if cfg_for_steps.exists() else config_path),
|
||||
],
|
||||
),
|
||||
(
|
||||
"actuator_stats",
|
||||
[run_dir / "actuator_stats.json"],
|
||||
[
|
||||
sys.executable,
|
||||
str(base_dir / "actuator_stats.py"),
|
||||
"--generated",
|
||||
str(run_dir / "generated.csv"),
|
||||
"--reference",
|
||||
str(config_path),
|
||||
"--config",
|
||||
str(cfg_for_steps if cfg_for_steps.exists() else config_path),
|
||||
],
|
||||
),
|
||||
(
|
||||
"pv_stats",
|
||||
[run_dir / "pv_stats.json"],
|
||||
[
|
||||
sys.executable,
|
||||
str(base_dir / "pv_stats.py"),
|
||||
"--generated",
|
||||
str(run_dir / "generated.csv"),
|
||||
"--reference",
|
||||
str(config_path),
|
||||
"--config",
|
||||
str(cfg_for_steps if cfg_for_steps.exists() else config_path),
|
||||
],
|
||||
),
|
||||
(
|
||||
"aux_stats",
|
||||
[run_dir / "aux_stats.json"],
|
||||
[
|
||||
sys.executable,
|
||||
str(base_dir / "aux_stats.py"),
|
||||
"--generated",
|
||||
str(run_dir / "generated.csv"),
|
||||
"--reference",
|
||||
str(config_path),
|
||||
"--config",
|
||||
str(cfg_for_steps if cfg_for_steps.exists() else config_path),
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
command_log = run_dir / "run_commands.txt"
|
||||
if not command_log.exists():
|
||||
command_log.write_text("", encoding="utf-8")
|
||||
|
||||
for stage, outputs, cmd in stage_defs:
|
||||
if stage_complete(state, stage, outputs, resume):
|
||||
print(f"skip_stage {stage}: outputs already present")
|
||||
state[stage] = "done"
|
||||
save_state(state_path, state)
|
||||
continue
|
||||
state[stage] = "running"
|
||||
save_state(state_path, state)
|
||||
with command_log.open("a", encoding="utf-8") as fh:
|
||||
fh.write(stage + ": " + " ".join(cmd) + "\n")
|
||||
run(cmd)
|
||||
state[stage] = "done"
|
||||
save_state(state_path, state)
|
||||
|
||||
print(f"pipeline_complete run_dir={run_dir}")
|
||||
|
||||
|
||||
# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
||||
53
example/submission_type_utils.py
Normal file
53
example/submission_type_utils.py
Normal file
@@ -0,0 +1,53 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Helpers for keeping type taxonomy and routing policy separate."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from data_utils import inverse_quantile_transform
|
||||
|
||||
|
||||
def resolve_taxonomy_features(config: Dict, cont_cols: List[str], base_key: str) -> List[str]:
    """Return the features listed under *base_key* in *config*, restricted to known columns.

    Missing or null config entries yield an empty list; the order of the
    configured list is preserved.
    """
    configured = config.get(base_key) or []
    known = set(cont_cols)
    return [name for name in configured if name in known]
|
||||
|
||||
|
||||
def resolve_routing_features(config: Dict, cont_cols: List[str], base_key: str) -> List[str]:
    """Return routing features for *base_key*, falling back to the taxonomy list.

    ``routing_<base_key>`` takes precedence when present in *config*;
    otherwise the plain *base_key* entry is used. Unknown columns are
    filtered out, preserving the configured order.
    """
    routing_key = f"routing_{base_key}"
    if routing_key in config:
        selected = config[routing_key]
    else:
        selected = config.get(base_key, [])
    selected = selected or []
    allowed = set(cont_cols)
    return [name for name in selected if name in allowed]
|
||||
|
||||
|
||||
def denormalize_cont_tensor(
    x: torch.Tensor,
    cont_cols: List[str],
    mean: Dict[str, float],
    std: Dict[str, float],
    transforms: Optional[Dict[str, str]] = None,
    quantile_probs: Optional[List[float]] = None,
    quantile_values: Optional[Dict[str, List[float]]] = None,
    use_quantile: bool = False,
) -> torch.Tensor:
    """Map normalized continuous features back to their raw scale.

    Args:
        x: Tensor whose LAST dimension indexes ``cont_cols``. Any number of
            leading dims is supported (e.g. ``(batch, seq, feat)`` or
            ``(seq, feat)``).
        cont_cols: Feature names in last-dimension order.
        mean: Per-feature means for the z-score inverse (ignored when
            ``use_quantile`` is True).
        std: Per-feature standard deviations for the z-score inverse.
        transforms: Optional per-feature forward-transform names; features
            tagged ``"log1p"`` are inverted here with ``expm1``.
        quantile_probs: Probabilities of the stored quantiles.
        quantile_values: Per-feature quantile values.
        use_quantile: When True, invert the quantile transform instead of
            the z-score normalization.

    Returns:
        A new tensor; the input is never modified in place.

    Raises:
        ValueError: If ``x`` is None, or ``use_quantile`` is True while the
            quantile statistics are missing.
    """
    if x is None:
        raise ValueError("x must not be None")
    if not cont_cols:
        return x.clone()

    out = x.clone()
    if use_quantile:
        if not quantile_probs or not quantile_values:
            raise ValueError("use_quantile=True but quantile stats are missing")
        q_vals = {c: quantile_values[c] for c in cont_cols}
        out = inverse_quantile_transform(out, cont_cols, quantile_probs, q_vals)
    else:
        mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=out.dtype, device=out.device)
        std_vec = torch.tensor([std[c] for c in cont_cols], dtype=out.dtype, device=out.device)
        # Broadcasting applies the per-feature affine over any leading dims.
        out = out * std_vec + mean_vec

    if transforms:
        for i, c in enumerate(cont_cols):
            if transforms.get(c) == "log1p":
                # `...` rather than `[:, :, i]` so non-3-D tensors work too;
                # behavior for the (batch, seq, feat) case is unchanged.
                out[..., i] = torch.expm1(out[..., i])
    return out
|
||||
534
example/train_resume.py
Normal file
534
example/train_resume.py
Normal file
@@ -0,0 +1,534 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Train hybrid diffusion with checkpoint resume and selective type-aware routing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from data_utils import load_split, windowed_batches
|
||||
from hybrid_diffusion import (
|
||||
HybridDiffusionModel,
|
||||
TemporalGRUGenerator,
|
||||
TemporalTransformerGenerator,
|
||||
cosine_beta_schedule,
|
||||
q_sample_continuous,
|
||||
q_sample_discrete,
|
||||
)
|
||||
from platform_utils import resolve_device, resolve_path, safe_path
|
||||
from submission_type_utils import resolve_routing_features, resolve_taxonomy_features
|
||||
from train import DEFAULTS, EMA, load_json, resolve_config_paths, set_seed
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
def load_torch_state(path: str, device: str):
    """Load a torch checkpoint onto *device*.

    Prefers ``weights_only=True`` (safe, no arbitrary unpickling); older
    torch releases that do not accept the keyword raise TypeError, in which
    case we fall back to the legacy call.
    """
    kwargs = {"map_location": device}
    try:
        return torch.load(path, weights_only=True, **kwargs)
    except TypeError:
        # torch < 1.13 has no weights_only parameter.
        return torch.load(path, **kwargs)
|
||||
|
||||
|
||||
def atomic_torch_save(obj, path: Path) -> None:
    """Serialize *obj* to *path* without ever exposing a partially-written file.

    The object is first written to a sibling ``<name>.tmp`` file and then
    moved over *path* with ``os.replace``, which replaces the destination
    atomically. Parent directories are created as needed.

    Raises:
        Whatever ``torch.save`` or ``os.replace`` raises; on failure the
        partial ``.tmp`` file is removed and any pre-existing file at
        *path* is left untouched.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    try:
        torch.save(obj, str(tmp))
        os.replace(str(tmp), str(path))
    except BaseException:
        # Don't leave a stale partial .tmp behind if the save fails.
        try:
            os.remove(str(tmp))
        except OSError:
            pass
        raise
|
||||
|
||||
|
||||
def parse_args():
    """Build and evaluate the command-line interface for resumable training."""
    p = argparse.ArgumentParser(description="Train hybrid diffusion on HAI with resume support.")
    p.add_argument("--config", default=None, help="Path to JSON config.")
    p.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    p.add_argument("--out-dir", default=None, help="Override output directory")
    p.add_argument("--seed", type=int, default=None, help="Override random seed")
    p.add_argument(
        "--temporal-only",
        action="store_true",
        help="Only train temporal stage-1 and exit.",
    )
    p.add_argument(
        "--resume",
        action="store_true",
        help="Resume from checkpoint in out-dir if present.",
    )
    p.add_argument(
        "--resume-ckpt",
        default=None,
        help="Optional explicit model checkpoint path.",
    )
    return p.parse_args()
|
||||
|
||||
|
||||
def build_temporal_model(config: Dict, model_cont_cols, temporal_cond_dim: int, device: str):
    """Instantiate the stage-1 temporal generator selected by ``temporal_backbone``.

    ``"transformer"`` builds a TemporalTransformerGenerator; any other value
    (the default is ``"gru"``) builds a TemporalGRUGenerator. Both are moved
    to *device* before being returned.
    """
    n_features = len(model_cont_cols)
    hidden = int(config.get("temporal_hidden_dim", 256))
    backbone = str(config.get("temporal_backbone", "gru"))
    if backbone == "transformer":
        generator = TemporalTransformerGenerator(
            input_dim=n_features,
            hidden_dim=hidden,
            num_layers=int(config.get("temporal_transformer_num_layers", 2)),
            nhead=int(config.get("temporal_transformer_nhead", 4)),
            ff_dim=int(config.get("temporal_transformer_ff_dim", 512)),
            dropout=float(config.get("temporal_transformer_dropout", 0.1)),
            pos_dim=int(config.get("temporal_pos_dim", 64)),
            use_pos_embed=bool(config.get("temporal_use_pos_embed", True)),
            cond_dim=temporal_cond_dim,
        )
    else:
        generator = TemporalGRUGenerator(
            input_dim=n_features,
            hidden_dim=hidden,
            num_layers=int(config.get("temporal_num_layers", 1)),
            dropout=float(config.get("temporal_dropout", 0.0)),
            cond_dim=temporal_cond_dim,
        )
    return generator.to(device)
|
||||
|
||||
|
||||
def init_or_append_log(log_path: Path, resume: bool) -> None:
    """Create a fresh CSV training log unless resuming with an existing one.

    When *resume* is True and the file already exists it is left untouched
    so later writes append to it; otherwise the file is (re)created with
    only the header row.
    """
    keep_existing = resume and log_path.exists()
    if keep_existing:
        return
    log_path.write_text("epoch,step,loss,loss_cont,loss_disc\n", encoding="utf-8")
|
||||
|
||||
|
||||
def main():
    """Entry point: resolve config, build the models, then run resumable two-stage training.

    Stage 1 (optional, ``use_temporal_stage1``) fits a temporal trend
    generator; stage 2 trains the hybrid diffusion model on the continuous
    residual (input minus trend) plus masked discrete tokens. Both stages
    periodically write resumable checkpoints via ``atomic_torch_save``.
    """
    args = parse_args()
    if args.config:
        print("using_config", str(Path(args.config).resolve()))
    # Layer the user config over DEFAULTS and resolve relative paths against
    # the config file's directory (or this script's directory when no config).
    config = dict(DEFAULTS)
    if args.config:
        cfg_path = Path(args.config).resolve()
        config.update(load_json(str(cfg_path)))
        config = resolve_config_paths(config, cfg_path.parent)
    else:
        config = resolve_config_paths(config, BASE_DIR)

    # CLI flags override the config file.
    if args.device != "auto":
        config["device"] = args.device
    if args.out_dir:
        out_dir = Path(args.out_dir)
        if not out_dir.is_absolute():
            base = Path(args.config).resolve().parent if args.config else BASE_DIR
            out_dir = resolve_path(base, out_dir)
        config["out_dir"] = str(out_dir)
    if args.seed is not None:
        config["seed"] = int(args.seed)
    if bool(args.temporal_only):
        # Force stage 1 on and skip the diffusion epochs entirely.
        config["use_temporal_stage1"] = True
        config["epochs"] = 0

    set_seed(int(config["seed"]))

    # Column bookkeeping: drop the time column everywhere and attack labels
    # from the discrete set.
    split = load_split(config["split_path"])
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    # Routing: type1/type5 features are removed from the diffusion model's
    # inputs (type1 may re-enter as conditioning); type4 only shapes the
    # trend mask below.
    type1_cols = resolve_routing_features(config, cont_cols, "type1_features")
    type5_cols = resolve_routing_features(config, cont_cols, "type5_features")
    type4_cols = resolve_taxonomy_features(config, cont_cols, "type4_features")
    model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
    if not model_cont_cols:
        raise SystemExit("model_cont_cols is empty; check routing_type1_features/routing_type5_features")

    # Normalization statistics; raw_std is used only for inverse-variance
    # loss weighting and falls back to std when absent.
    stats = load_json(config["stats_path"])
    mean = stats["mean"]
    std = stats["std"]
    transforms = stats.get("transform", {})
    raw_std = stats.get("raw_std", std)
    quantile_probs = stats.get("quantile_probs")
    quantile_values = stats.get("quantile_values")
    use_quantile = bool(config.get("use_quantile_transform", False))

    vocab = load_json(config["vocab_path"])["vocab"]
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    # Prefer the glob (multiple training files) and fall back to the single
    # data_path. Sorted so file_id conditioning stays stable across runs.
    data_paths = None
    if "data_glob" in config and config["data_glob"]:
        data_paths = sorted(Path(config["data_glob"]).parent.glob(Path(config["data_glob"]).name))
        if data_paths:
            data_paths = [safe_path(p) for p in data_paths]
    if not data_paths:
        data_paths = [safe_path(config["data_path"])]

    # file_id conditioning: one embedding slot per input file.
    use_condition = bool(config.get("use_condition")) and config.get("condition_type") == "file_id"
    cond_vocab_size = len(data_paths) if use_condition else 0

    device = resolve_device(str(config["device"]))
    print("device", device)
    model = HybridDiffusionModel(
        cont_dim=len(model_cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=int(config.get("model_time_dim", 64)),
        hidden_dim=int(config.get("model_hidden_dim", 256)),
        num_layers=int(config.get("model_num_layers", 1)),
        dropout=float(config.get("model_dropout", 0.0)),
        ff_mult=int(config.get("model_ff_mult", 2)),
        pos_dim=int(config.get("model_pos_dim", 64)),
        use_pos_embed=bool(config.get("model_use_pos_embed", True)),
        backbone_type=str(config.get("backbone_type", "gru")),
        transformer_num_layers=int(config.get("transformer_num_layers", 4)),
        transformer_nhead=int(config.get("transformer_nhead", 8)),
        transformer_ff_dim=int(config.get("transformer_ff_dim", 2048)),
        transformer_dropout=float(config.get("transformer_dropout", 0.1)),
        cond_cont_dim=len(type1_cols),
        cond_vocab_size=cond_vocab_size,
        cond_dim=int(config.get("cond_dim", 32)),
        use_tanh_eps=bool(config.get("use_tanh_eps", False)),
        eps_scale=float(config.get("eps_scale", 1.0)),
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))

    # Stage-1 temporal generator (optional). The trend mask restricts which
    # features the trend applies to: "focus" keeps only type4 features,
    # "exclude" zeroes them out.
    temporal_model = None
    opt_temporal = None
    temporal_use_type1_cond = bool(config.get("temporal_use_type1_cond", False))
    temporal_cond_dim = len(type1_cols) if (temporal_use_type1_cond and type1_cols) else 0
    temporal_focus_type4 = bool(config.get("temporal_focus_type4", False))
    temporal_exclude_type4 = bool(config.get("temporal_exclude_type4", False))
    type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols]
    trend_mask = None
    if temporal_focus_type4 and type4_model_idx:
        trend_mask = torch.zeros(1, 1, len(model_cont_cols), device=device)
        trend_mask[:, :, type4_model_idx] = 1.0
    elif temporal_exclude_type4 and type4_model_idx:
        trend_mask = torch.ones(1, 1, len(model_cont_cols), device=device)
        trend_mask[:, :, type4_model_idx] = 0.0
    if bool(config.get("use_temporal_stage1", False)):
        temporal_model = build_temporal_model(config, model_cont_cols, temporal_cond_dim, device)
        opt_temporal = torch.optim.Adam(
            temporal_model.parameters(),
            lr=float(config.get("temporal_lr", config["lr"])),
        )
    ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None

    # Diffusion noise schedule.
    betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    os.makedirs(config["out_dir"], exist_ok=True)
    out_dir = Path(safe_path(config["out_dir"]))
    log_path = out_dir / "train_log.csv"
    init_or_append_log(log_path, args.resume)

    # Snapshot the effective config for reproducibility.
    with open(out_dir / "config_used.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    # Checkpoint/output locations. *_ckpt.pt files are resumable snapshots;
    # model.pt / model_ema.pt / temporal.pt are the final artifacts.
    main_ckpt_path = Path(args.resume_ckpt).resolve() if args.resume_ckpt else (out_dir / "model_ckpt.pt")
    temporal_ckpt_path = out_dir / "temporal_ckpt.pt"
    model_path = out_dir / "model.pt"
    ema_path = out_dir / "model_ema.pt"
    temporal_path = out_dir / "temporal.pt"

    # Resume cursors; temporal_done is trivially True when stage 1 is off.
    temporal_start_epoch = 0
    temporal_start_step = 0
    temporal_total_step = 0
    main_start_epoch = 0
    main_start_step = 0
    total_step = 0
    temporal_done = temporal_model is None

    if args.resume:
        # Precedence: a main-stage checkpoint implies stage 1 completed (or
        # its state rides along inside it); else a mid-stage-1 checkpoint;
        # else a finished temporal.pt from a previous run.
        if main_ckpt_path.exists():
            ckpt = load_torch_state(str(main_ckpt_path), device)
            model.load_state_dict(ckpt["model"])
            opt.load_state_dict(ckpt["optim"])
            total_step = int(ckpt.get("step", 0))
            main_start_epoch = int(ckpt.get("epoch", 0))
            main_start_step = int(ckpt.get("step_in_epoch", 0))
            temporal_done = bool(ckpt.get("temporal_done", temporal_done))
            temporal_total_step = int(ckpt.get("temporal_step", 0))
            if ema is not None and ckpt.get("ema") is not None:
                # NOTE(review): checkpoints store ema.state_dict() but resume
                # assigns to ema.shadow — confirm EMA.state_dict() returns the
                # shadow mapping (EMA is defined in train.py, not visible here).
                ema.shadow = ckpt["ema"]
            if temporal_model is not None and ckpt.get("temporal") is not None:
                temporal_model.load_state_dict(ckpt["temporal"])
            if opt_temporal is not None and ckpt.get("temporal_optim") is not None:
                opt_temporal.load_state_dict(ckpt["temporal_optim"])
            print(f"resumed_main_ckpt epoch={main_start_epoch} step={main_start_step} total_step={total_step}")
        elif temporal_ckpt_path.exists() and temporal_model is not None and opt_temporal is not None:
            tckpt = load_torch_state(str(temporal_ckpt_path), device)
            temporal_model.load_state_dict(tckpt["temporal"])
            opt_temporal.load_state_dict(tckpt["temporal_optim"])
            temporal_start_epoch = int(tckpt.get("epoch", 0))
            temporal_start_step = int(tckpt.get("step_in_epoch", 0))
            temporal_total_step = int(tckpt.get("temporal_step", 0))
            print(
                f"resumed_temporal_ckpt epoch={temporal_start_epoch} "
                f"step={temporal_start_step} temporal_step={temporal_total_step}"
            )
        elif temporal_path.exists() and temporal_model is not None:
            temporal_model.load_state_dict(load_torch_state(str(temporal_path), device))
            temporal_done = True
            print("reused_completed_temporal_stage", str(temporal_path))

    # ---- Stage 1: temporal trend training (teacher-forced next-step MSE) ----
    if temporal_model is not None and opt_temporal is not None and not temporal_done:
        for epoch in range(temporal_start_epoch, int(config.get("temporal_epochs", 1))):
            # Fast-forward past batches already consumed when resuming mid-epoch.
            skip_until = temporal_start_step if epoch == temporal_start_epoch else 0
            for step, batch in enumerate(
                windowed_batches(
                    data_paths,
                    cont_cols,
                    disc_cols,
                    vocab,
                    mean,
                    std,
                    batch_size=int(config["batch_size"]),
                    seq_len=int(config["seq_len"]),
                    max_batches=int(config["max_batches"]),
                    return_file_id=False,
                    transforms=transforms,
                    quantile_probs=quantile_probs,
                    quantile_values=quantile_values,
                    use_quantile=use_quantile,
                    shuffle_buffer=int(config.get("shuffle_buffer", 0)),
                )
            ):
                if step < skip_until:
                    continue
                x_cont, _ = batch
                x_cont = x_cont.to(device)
                model_idx = [cont_cols.index(c) for c in model_cont_cols]
                x_cont_model = x_cont[:, :, model_idx]
                cond_cont = None
                if temporal_cond_dim > 0:
                    cond_idx = [cont_cols.index(c) for c in type1_cols]
                    cond_cont = x_cont[:, :, cond_idx]
                # Predict step t+1 from steps <= t; target is the shifted input.
                _, pred_next = temporal_model.forward_teacher(x_cont_model, cond_cont=cond_cont)
                target_next = x_cont_model[:, 1:, :]
                if trend_mask is not None:
                    # Masked mean: average the squared error only over the
                    # selected feature channels.
                    mask = trend_mask.to(dtype=pred_next.dtype, device=pred_next.device)
                    mse = (pred_next - target_next) ** 2
                    temporal_loss = (mse * mask).sum() / torch.clamp(mask.sum() * mse.size(0) * mse.size(1), min=1.0)
                else:
                    temporal_loss = F.mse_loss(pred_next, target_next)
                opt_temporal.zero_grad()
                temporal_loss.backward()
                if float(config.get("grad_clip", 0.0)) > 0:
                    torch.nn.utils.clip_grad_norm_(temporal_model.parameters(), float(config["grad_clip"]))
                opt_temporal.step()
                temporal_total_step += 1
                if step % int(config["log_every"]) == 0:
                    print("temporal_epoch", epoch, "step", step, "loss", float(temporal_loss))
                if temporal_total_step % int(config["ckpt_every"]) == 0:
                    # Mid-epoch snapshot; step_in_epoch points at the NEXT batch.
                    atomic_torch_save(
                        {
                            "temporal": temporal_model.state_dict(),
                            "temporal_optim": opt_temporal.state_dict(),
                            "epoch": epoch,
                            "step_in_epoch": step + 1,
                            "temporal_step": temporal_total_step,
                            "config": config,
                        },
                        temporal_ckpt_path,
                    )
            # After the first resumed epoch, later epochs start from batch 0.
            temporal_start_step = 0
            # End-of-epoch snapshot pointing at the next epoch.
            atomic_torch_save(
                {
                    "temporal": temporal_model.state_dict(),
                    "temporal_optim": opt_temporal.state_dict(),
                    "epoch": epoch + 1,
                    "step_in_epoch": 0,
                    "temporal_step": temporal_total_step,
                    "config": config,
                },
                temporal_ckpt_path,
            )
        atomic_torch_save(temporal_model.state_dict(), temporal_path)
        temporal_done = True

    if bool(args.temporal_only):
        return

    # ---- Stage 2: hybrid diffusion training on the trend residual ----
    for epoch in range(main_start_epoch, int(config["epochs"])):
        skip_until = main_start_step if epoch == main_start_epoch else 0
        for step, batch in enumerate(
            windowed_batches(
                data_paths,
                cont_cols,
                disc_cols,
                vocab,
                mean,
                std,
                batch_size=int(config["batch_size"]),
                seq_len=int(config["seq_len"]),
                max_batches=int(config["max_batches"]),
                return_file_id=use_condition,
                transforms=transforms,
                quantile_probs=quantile_probs,
                quantile_values=quantile_values,
                use_quantile=use_quantile,
                shuffle_buffer=int(config.get("shuffle_buffer", 0)),
            )
        ):
            if step < skip_until:
                continue
            if use_condition:
                x_cont, x_disc, cond = batch
                cond = cond.to(device)
            else:
                x_cont, x_disc = batch
                cond = None
            x_cont = x_cont.to(device)
            x_disc = x_disc.to(device)

            # NOTE: these index lists are loop-invariant and could be hoisted
            # above the epoch loop.
            model_idx = [cont_cols.index(c) for c in model_cont_cols]
            cond_idx = [cont_cols.index(c) for c in type1_cols] if type1_cols else []
            x_cont_model = x_cont[:, :, model_idx]
            cond_cont = x_cont[:, :, cond_idx] if cond_idx else None

            # Subtract the (frozen) stage-1 trend so diffusion models the residual.
            trend = None
            if temporal_model is not None:
                temporal_model.eval()
                with torch.no_grad():
                    trend, _ = temporal_model.forward_teacher(x_cont_model, cond_cont=cond_cont)
                if trend_mask is not None and trend is not None:
                    trend = trend * trend_mask.to(dtype=trend.dtype, device=trend.device)
            x_cont_resid = x_cont_model if trend is None else x_cont_model - trend

            # Sample one diffusion timestep per sequence in the batch.
            bsz = x_cont.size(0)
            t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)

            x_cont_t, noise = q_sample_continuous(x_cont_resid, t, alphas_cumprod)

            # Discrete corruption: each vocab's extra id (== vocab size) is
            # its MASK token.
            mask_tokens = torch.tensor(vocab_sizes, device=device)
            x_disc_t, mask = q_sample_discrete(
                x_disc,
                t,
                mask_tokens,
                int(config["timesteps"]),
                mask_scale=float(config.get("disc_mask_scale", 1.0)),
            )

            eps_pred, logits = model(x_cont_t, x_disc_t, t, cond, cond_cont=cond_cont)

            # Continuous loss: the network predicts either x0 or epsilon
            # depending on cont_target.
            cont_target = str(config.get("cont_target", "eps"))
            if cont_target == "x0":
                x0_target = x_cont_resid
                if float(config.get("cont_clamp_x0", 0.0)) > 0:
                    x0_target = torch.clamp(
                        x0_target,
                        -float(config["cont_clamp_x0"]),
                        float(config["cont_clamp_x0"]),
                    )
                loss_base = (eps_pred - x0_target) ** 2
            else:
                loss_base = (eps_pred - noise) ** 2

            # Optional inverse-variance weighting of per-feature errors.
            if config.get("cont_loss_weighting") == "inv_std":
                weights = torch.tensor(
                    [1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in model_cont_cols],
                    device=device,
                    dtype=eps_pred.dtype,
                ).view(1, 1, -1)
                loss_cont = (loss_base * weights).mean()
            else:
                loss_cont = loss_base.mean()

            # Optional min-SNR-style reweighting by the batch-mean SNR weight.
            if bool(config.get("snr_weighted_loss", False)):
                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
                snr = a_bar_t / torch.clamp(1.0 - a_bar_t, min=1e-8)
                gamma = float(config.get("snr_gamma", 1.0))
                snr_weight = snr / (snr + gamma)
                loss_cont = (loss_cont * snr_weight.mean()).mean()
            # Discrete loss: cross-entropy only at masked positions, averaged
            # over the features that had any masked position this batch.
            loss_disc = 0.0
            loss_disc_count = 0
            for i, logit in enumerate(logits):
                if mask[:, :, i].any():
                    loss_disc = loss_disc + F.cross_entropy(
                        logit[mask[:, :, i]],
                        x_disc[:, :, i][mask[:, :, i]],
                    )
                    loss_disc_count += 1
            if loss_disc_count > 0:
                loss_disc = loss_disc / loss_disc_count

            # lambda blends continuous vs discrete objectives.
            lam = float(config["lambda"])
            loss = lam * loss_cont + (1 - lam) * loss_disc

            # Optional quantile-matching penalty between reconstructed x0 and
            # the real residual, pooled over batch and time.
            q_weight = float(config.get("quantile_loss_weight", 0.0))
            if q_weight > 0:
                q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
                q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
                x_real = x_cont_resid
                if cont_target == "x0":
                    x_gen = eps_pred
                else:
                    # Reconstruct x0 from the noised sample and predicted eps.
                    x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
                x_real = x_real.view(-1, x_real.size(-1))
                x_gen = x_gen.view(-1, x_gen.size(-1))
                q_real = torch.quantile(x_real, q_tensor, dim=0)
                q_gen = torch.quantile(x_gen, q_tensor, dim=0)
                quantile_loss = torch.mean(torch.abs(q_gen - q_real))
                loss = loss + q_weight * quantile_loss

            # Optional first/second-moment matching penalty on the same x0.
            stat_weight = float(config.get("residual_stat_weight", 0.0))
            if stat_weight > 0:
                a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
                if cont_target == "x0":
                    x_gen = eps_pred
                else:
                    x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
                x_real = x_cont_resid
                mean_real = x_real.mean(dim=(0, 1))
                mean_gen = x_gen.mean(dim=(0, 1))
                std_real = x_real.std(dim=(0, 1))
                std_gen = x_gen.std(dim=(0, 1))
                stat_loss = F.mse_loss(mean_gen, mean_real) + F.mse_loss(std_gen, std_real)
                loss = loss + stat_weight * stat_loss
            opt.zero_grad()
            loss.backward()
            if float(config.get("grad_clip", 0.0)) > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
            opt.step()
            if ema is not None:
                ema.update(model)

            if step % int(config["log_every"]) == 0:
                print("epoch", epoch, "step", step, "loss", float(loss))
                with open(log_path, "a", encoding="utf-8") as f:
                    f.write(
                        "%d,%d,%.6f,%.6f,%.6f\n"
                        % (epoch, step, float(loss), float(loss_cont), float(loss_disc))
                    )

            total_step += 1
            if total_step % int(config["ckpt_every"]) == 0:
                # Mid-epoch resumable snapshot; carries EMA and stage-1 state
                # so a single file restores the whole pipeline.
                ckpt = {
                    "model": model.state_dict(),
                    "optim": opt.state_dict(),
                    "config": config,
                    "step": total_step,
                    "epoch": epoch,
                    "step_in_epoch": step + 1,
                    "temporal_done": temporal_done,
                    "temporal_step": temporal_total_step,
                }
                if ema is not None:
                    ckpt["ema"] = ema.state_dict()
                if temporal_model is not None:
                    ckpt["temporal"] = temporal_model.state_dict()
                if opt_temporal is not None:
                    ckpt["temporal_optim"] = opt_temporal.state_dict()
                atomic_torch_save(ckpt, main_ckpt_path)

        # After the first resumed epoch, later epochs start from batch 0.
        main_start_step = 0
        # End-of-epoch snapshot pointing at the next epoch.
        ckpt = {
            "model": model.state_dict(),
            "optim": opt.state_dict(),
            "config": config,
            "step": total_step,
            "epoch": epoch + 1,
            "step_in_epoch": 0,
            "temporal_done": temporal_done,
            "temporal_step": temporal_total_step,
        }
        if ema is not None:
            ckpt["ema"] = ema.state_dict()
        if temporal_model is not None:
            ckpt["temporal"] = temporal_model.state_dict()
        if opt_temporal is not None:
            ckpt["temporal_optim"] = opt_temporal.state_dict()
        atomic_torch_save(ckpt, main_ckpt_path)

    # Final artifacts (non-resumable weights only).
    atomic_torch_save(model.state_dict(), model_path)
    if ema is not None:
        atomic_torch_save(ema.state_dict(), ema_path)
    if temporal_model is not None:
        atomic_torch_save(temporal_model.state_dict(), temporal_path)
|
||||
|
||||
|
||||
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user