535 lines
23 KiB
Python
535 lines
23 KiB
Python
#!/usr/bin/env python3
|
|
"""Train hybrid diffusion with checkpoint resume and selective type-aware routing."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Dict
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from data_utils import load_split, windowed_batches
|
|
from hybrid_diffusion import (
|
|
HybridDiffusionModel,
|
|
TemporalGRUGenerator,
|
|
TemporalTransformerGenerator,
|
|
cosine_beta_schedule,
|
|
q_sample_continuous,
|
|
q_sample_discrete,
|
|
)
|
|
from platform_utils import resolve_device, resolve_path, safe_path
|
|
from submission_type_utils import resolve_routing_features, resolve_taxonomy_features
|
|
from train import DEFAULTS, EMA, load_json, resolve_config_paths, set_seed
|
|
|
|
# Directory containing this script. Used as the base for resolving config
# paths and a relative --out-dir when no --config file is given.
BASE_DIR = Path(__file__).resolve().parent
|
|
|
|
|
|
def load_torch_state(path: str, device: str):
    """Load a serialized torch object from *path*, mapping tensors onto *device*.

    The restricted ``weights_only=True`` loader is tried first. If ``torch.load``
    raises ``TypeError`` (presumably an older torch that does not accept the
    ``weights_only`` keyword), the load is retried without that keyword.
    """
    load_kwargs = {"map_location": device, "weights_only": True}
    try:
        return torch.load(path, **load_kwargs)
    except TypeError:
        del load_kwargs["weights_only"]
        return torch.load(path, **load_kwargs)
|
|
|
|
|
|
def atomic_torch_save(obj, path: Path) -> None:
    """Serialize *obj* with ``torch.save`` and move it onto *path* atomically.

    The object is first written to a sibling ``<name>.tmp`` file. That file is
    then renamed over *path* with ``os.replace``, so *path* never holds a
    half-written file. Missing parent directories are created.

    If serialization or the rename fails, the partial temporary file is
    removed and the original exception is re-raised. Any existing file at
    *path* is left untouched.

    Args:
        obj: Any object accepted by ``torch.save``.
        path: Destination path (``pathlib.Path`` or ``str``).
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # "model.pt" -> "model.pt.tmp" (same result as the old with_suffix form).
    tmp = path.with_name(path.name + ".tmp")
    try:
        torch.save(obj, str(tmp))
        os.replace(str(tmp), str(path))
    except BaseException:
        # Don't leave a stale partial file behind (pickling error, disk full,
        # KeyboardInterrupt mid-write, ...).
        try:
            tmp.unlink()
        except FileNotFoundError:
            pass
        raise
|
|
|
|
|
|
def parse_args():
    """Define this script's command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``config``, ``device``, ``out_dir``,
        ``seed``, ``temporal_only``, ``resume`` and ``resume_ckpt``.
    """
    cli = argparse.ArgumentParser(description="Train hybrid diffusion on HAI with resume support.")
    # Config file and overrides for values that otherwise come from it.
    cli.add_argument("--config", default=None, help="Path to JSON config.")
    cli.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    cli.add_argument("--out-dir", default=None, help="Override output directory")
    cli.add_argument("--seed", type=int, default=None, help="Override random seed")
    # Boolean switches for stage selection and resumption.
    switch_specs = (
        ("--temporal-only", "Only train temporal stage-1 and exit."),
        ("--resume", "Resume from checkpoint in out-dir if present."),
    )
    for flag, help_text in switch_specs:
        cli.add_argument(flag, action="store_true", help=help_text)
    cli.add_argument("--resume-ckpt", default=None, help="Optional explicit model checkpoint path.")
    namespace = cli.parse_args()
    return namespace
|
|
|
|
|
|
def build_temporal_model(config: Dict, model_cont_cols, temporal_cond_dim: int, device: str):
    """Instantiate the stage-1 temporal generator selected by *config*.

    ``config["temporal_backbone"] == "transformer"`` selects
    ``TemporalTransformerGenerator``. Any other value (default ``"gru"``)
    selects ``TemporalGRUGenerator``.

    Both generators take ``len(model_cont_cols)`` input features and
    ``temporal_cond_dim`` conditioning channels. The module is moved to
    *device* before it is returned.
    """
    feature_dim = len(model_cont_cols)
    hidden = int(config.get("temporal_hidden_dim", 256))
    if str(config.get("temporal_backbone", "gru")) != "transformer":
        generator = TemporalGRUGenerator(
            input_dim=feature_dim,
            hidden_dim=hidden,
            num_layers=int(config.get("temporal_num_layers", 1)),
            dropout=float(config.get("temporal_dropout", 0.0)),
            cond_dim=temporal_cond_dim,
        )
    else:
        generator = TemporalTransformerGenerator(
            input_dim=feature_dim,
            hidden_dim=hidden,
            num_layers=int(config.get("temporal_transformer_num_layers", 2)),
            nhead=int(config.get("temporal_transformer_nhead", 4)),
            ff_dim=int(config.get("temporal_transformer_ff_dim", 512)),
            dropout=float(config.get("temporal_transformer_dropout", 0.1)),
            pos_dim=int(config.get("temporal_pos_dim", 64)),
            use_pos_embed=bool(config.get("temporal_use_pos_embed", True)),
            cond_dim=temporal_cond_dim,
        )
    return generator.to(device)
|
|
|
|
|
|
def init_or_append_log(log_path: Path, resume: bool) -> None:
    """Prepare the CSV training log at *log_path*.

    When *resume* is true and the file already exists, it is left untouched
    so earlier rows are preserved. In every other case the file is
    (re)created and contains only the header row.
    """
    keep_existing = resume and log_path.exists()
    if not keep_existing:
        with open(log_path, "w", encoding="utf-8") as handle:
            handle.write("epoch,step,loss,loss_cont,loss_disc\n")
|
|
|
|
|
|
def _load_config(args) -> Dict:
    """Build the run config from DEFAULTS, the optional JSON file and CLI overrides.

    ``resolve_config_paths`` is applied with the config file's directory as
    base, or ``BASE_DIR`` when no ``--config`` is given. A relative
    ``--out-dir`` is resolved against that same base, not the current working
    directory. ``--temporal-only`` forces the temporal stage on and sets
    ``epochs`` to 0, so the main diffusion loop does not run.
    """
    if args.config:
        print("using_config", str(Path(args.config).resolve()))
    config = dict(DEFAULTS)
    if args.config:
        cfg_path = Path(args.config).resolve()
        config.update(load_json(str(cfg_path)))
        config = resolve_config_paths(config, cfg_path.parent)
    else:
        config = resolve_config_paths(config, BASE_DIR)

    if args.device != "auto":
        config["device"] = args.device
    if args.out_dir:
        out_dir = Path(args.out_dir)
        if not out_dir.is_absolute():
            base = Path(args.config).resolve().parent if args.config else BASE_DIR
            out_dir = resolve_path(base, out_dir)
        config["out_dir"] = str(out_dir)
    if args.seed is not None:
        config["seed"] = int(args.seed)
    if bool(args.temporal_only):
        config["use_temporal_stage1"] = True
        config["epochs"] = 0
    return config


def _resolve_data_paths(config: Dict) -> list:
    """Return the training files.

    These are the sorted matches of ``data_glob`` when that key is set and
    matches something. Otherwise the result is ``[data_path]``.
    """
    data_paths = []
    data_glob = config.get("data_glob")
    if data_glob:
        pattern = Path(data_glob)
        data_paths = [safe_path(p) for p in sorted(pattern.parent.glob(pattern.name))]
    if not data_paths:
        data_paths = [safe_path(config["data_path"])]
    return data_paths


def _build_trend_mask(model_cont_cols, type4_cols, focus_type4: bool, exclude_type4: bool, device):
    """Return a (1, 1, F) 0/1 feature mask for the temporal trend, or None.

    ``focus_type4`` keeps only the type-4 features. ``exclude_type4`` keeps
    every feature except them, and ``focus_type4`` wins if both are set. The
    result is None when neither flag is set or when no type-4 feature is
    among *model_cont_cols*.
    """
    type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols]
    if not type4_model_idx:
        return None
    if focus_type4:
        mask = torch.zeros(1, 1, len(model_cont_cols), device=device)
        mask[:, :, type4_model_idx] = 1.0
        return mask
    if exclude_type4:
        mask = torch.ones(1, 1, len(model_cont_cols), device=device)
        mask[:, :, type4_model_idx] = 0.0
        return mask
    return None


def _predicted_x0(eps_pred, x_cont_t, a_bar_t, cont_target: str):
    """Return the model's clean-sample estimate of the continuous residual.

    With ``cont_target == "x0"`` the network output already is that estimate.
    Otherwise the output is treated as a noise estimate, and x0 is recovered by
    inverting ``x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps``.
    """
    if cont_target == "x0":
        return eps_pred
    return (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)


def _main_checkpoint(
    model,
    opt,
    ema,
    temporal_model,
    opt_temporal,
    config,
    *,
    total_step,
    epoch,
    step_in_epoch,
    temporal_done,
    temporal_total_step,
) -> Dict:
    """Assemble the resumable main-stage checkpoint dict (optional parts only when present)."""
    ckpt = {
        "model": model.state_dict(),
        "optim": opt.state_dict(),
        "config": config,
        "step": total_step,
        "epoch": epoch,
        "step_in_epoch": step_in_epoch,
        "temporal_done": temporal_done,
        "temporal_step": temporal_total_step,
    }
    if ema is not None:
        ckpt["ema"] = ema.state_dict()
    if temporal_model is not None:
        ckpt["temporal"] = temporal_model.state_dict()
    if opt_temporal is not None:
        ckpt["temporal_optim"] = opt_temporal.state_dict()
    return ckpt


def _temporal_checkpoint(temporal_model, opt_temporal, config, *, epoch, step_in_epoch, temporal_total_step) -> Dict:
    """Assemble the resumable temporal-stage checkpoint dict."""
    return {
        "temporal": temporal_model.state_dict(),
        "temporal_optim": opt_temporal.state_dict(),
        "epoch": epoch,
        "step_in_epoch": step_in_epoch,
        "temporal_step": temporal_total_step,
        "config": config,
    }


def main():
    """Train the optional temporal stage-1 model, then the hybrid diffusion model.

    Stage 1 (``use_temporal_stage1``) fits a temporal generator with a
    next-step MSE loss on the model-routed continuous features. Stage 2 trains
    ``HybridDiffusionModel`` on two inputs: the continuous residual
    ``x - trend``, and masked discrete tokens.

    Both stages write periodic and end-of-epoch checkpoints. ``--resume`` can
    therefore restart mid-epoch by skipping the first ``step_in_epoch``
    batches. NOTE(review): this assumes ``windowed_batches`` yields batches in
    the same order on every pass, which may not hold with ``shuffle_buffer > 0``.

    Files written under ``out_dir``: ``train_log.csv``, ``config_used.json``,
    ``temporal_ckpt.pt``, ``temporal.pt``, ``model.pt``, ``model_ema.pt``
    (with EMA), and ``model_ckpt.pt``. When ``--resume-ckpt`` is given, the
    main checkpoint is written to that path instead of ``model_ckpt.pt``.
    """
    args = parse_args()
    config = _load_config(args)
    # Checkpoints are also *written* to an explicit --resume-ckpt path. Without
    # --resume that path was never loaded, so a fresh run silently overwrote it.
    # Treat --resume-ckpt as implying --resume.
    resume = bool(args.resume or args.resume_ckpt)

    set_seed(int(config["seed"]))

    split = load_split(config["split_path"])
    time_col = split.get("time_column", "time")
    cont_cols = [c for c in split["continuous"] if c != time_col]
    disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]

    # Type-1 features are fed as conditioning, and type-5 features are routed
    # away from the model. Neither is part of the continuous diffusion target.
    type1_cols = resolve_routing_features(config, cont_cols, "type1_features")
    type5_cols = resolve_routing_features(config, cont_cols, "type5_features")
    type4_cols = resolve_taxonomy_features(config, cont_cols, "type4_features")
    model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
    if not model_cont_cols:
        raise SystemExit("model_cont_cols is empty; check routing_type1_features/routing_type5_features")
    # Column positions are fixed for the whole run: compute once, not per batch.
    model_idx = [cont_cols.index(c) for c in model_cont_cols]
    cond_idx = [cont_cols.index(c) for c in type1_cols]

    stats = load_json(config["stats_path"])
    mean = stats["mean"]
    std = stats["std"]
    transforms = stats.get("transform", {})
    raw_std = stats.get("raw_std", std)
    quantile_probs = stats.get("quantile_probs")
    quantile_values = stats.get("quantile_values")
    use_quantile = bool(config.get("use_quantile_transform", False))

    vocab = load_json(config["vocab_path"])["vocab"]
    vocab_sizes = [len(vocab[c]) for c in disc_cols]

    data_paths = _resolve_data_paths(config)
    use_condition = bool(config.get("use_condition")) and config.get("condition_type") == "file_id"
    cond_vocab_size = len(data_paths) if use_condition else 0

    def make_batches(return_file_id: bool):
        """Start a new pass of windowed_batches with this run's data and normalization settings."""
        return windowed_batches(
            data_paths,
            cont_cols,
            disc_cols,
            vocab,
            mean,
            std,
            batch_size=int(config["batch_size"]),
            seq_len=int(config["seq_len"]),
            max_batches=int(config["max_batches"]),
            return_file_id=return_file_id,
            transforms=transforms,
            quantile_probs=quantile_probs,
            quantile_values=quantile_values,
            use_quantile=use_quantile,
            shuffle_buffer=int(config.get("shuffle_buffer", 0)),
        )

    device = resolve_device(str(config["device"]))
    print("device", device)
    model = HybridDiffusionModel(
        cont_dim=len(model_cont_cols),
        disc_vocab_sizes=vocab_sizes,
        time_dim=int(config.get("model_time_dim", 64)),
        hidden_dim=int(config.get("model_hidden_dim", 256)),
        num_layers=int(config.get("model_num_layers", 1)),
        dropout=float(config.get("model_dropout", 0.0)),
        ff_mult=int(config.get("model_ff_mult", 2)),
        pos_dim=int(config.get("model_pos_dim", 64)),
        use_pos_embed=bool(config.get("model_use_pos_embed", True)),
        backbone_type=str(config.get("backbone_type", "gru")),
        transformer_num_layers=int(config.get("transformer_num_layers", 4)),
        transformer_nhead=int(config.get("transformer_nhead", 8)),
        transformer_ff_dim=int(config.get("transformer_ff_dim", 2048)),
        transformer_dropout=float(config.get("transformer_dropout", 0.1)),
        cond_cont_dim=len(type1_cols),
        cond_vocab_size=cond_vocab_size,
        cond_dim=int(config.get("cond_dim", 32)),
        use_tanh_eps=bool(config.get("use_tanh_eps", False)),
        eps_scale=float(config.get("eps_scale", 1.0)),
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))

    temporal_use_type1_cond = bool(config.get("temporal_use_type1_cond", False))
    temporal_cond_dim = len(type1_cols) if (temporal_use_type1_cond and type1_cols) else 0
    trend_mask = _build_trend_mask(
        model_cont_cols,
        type4_cols,
        bool(config.get("temporal_focus_type4", False)),
        bool(config.get("temporal_exclude_type4", False)),
        device,
    )
    temporal_model = None
    opt_temporal = None
    if bool(config.get("use_temporal_stage1", False)):
        temporal_model = build_temporal_model(config, model_cont_cols, temporal_cond_dim, device)
        opt_temporal = torch.optim.Adam(
            temporal_model.parameters(),
            lr=float(config.get("temporal_lr", config["lr"])),
        )
    ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None

    timesteps = int(config["timesteps"])
    betas = cosine_beta_schedule(timesteps).to(device)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    os.makedirs(config["out_dir"], exist_ok=True)
    out_dir = Path(safe_path(config["out_dir"]))
    log_path = out_dir / "train_log.csv"
    init_or_append_log(log_path, resume)

    with open(out_dir / "config_used.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    main_ckpt_path = Path(args.resume_ckpt).resolve() if args.resume_ckpt else (out_dir / "model_ckpt.pt")
    temporal_ckpt_path = out_dir / "temporal_ckpt.pt"
    model_path = out_dir / "model.pt"
    ema_path = out_dir / "model_ema.pt"
    temporal_path = out_dir / "temporal.pt"

    temporal_start_epoch = 0
    temporal_start_step = 0
    temporal_total_step = 0
    main_start_epoch = 0
    main_start_step = 0
    total_step = 0
    temporal_done = temporal_model is None

    # ---- Resume: main ckpt > temporal ckpt > completed temporal weights ----
    if resume:
        if args.resume_ckpt and not main_ckpt_path.exists():
            print("resume_ckpt_missing", str(main_ckpt_path))
        if main_ckpt_path.exists():
            ckpt = load_torch_state(str(main_ckpt_path), device)
            model.load_state_dict(ckpt["model"])
            opt.load_state_dict(ckpt["optim"])
            total_step = int(ckpt.get("step", 0))
            main_start_epoch = int(ckpt.get("epoch", 0))
            main_start_step = int(ckpt.get("step_in_epoch", 0))
            temporal_done = bool(ckpt.get("temporal_done", temporal_done))
            temporal_total_step = int(ckpt.get("temporal_step", 0))
            if ema is not None and ckpt.get("ema") is not None:
                # NOTE(review): checkpoints store ema.state_dict() but restore by
                # assigning ema.shadow; this assumes EMA.state_dict() returns the
                # shadow mapping. Confirm against train.EMA.
                ema.shadow = ckpt["ema"]
            if temporal_model is not None and ckpt.get("temporal") is not None:
                temporal_model.load_state_dict(ckpt["temporal"])
            if opt_temporal is not None and ckpt.get("temporal_optim") is not None:
                opt_temporal.load_state_dict(ckpt["temporal_optim"])
            print(f"resumed_main_ckpt epoch={main_start_epoch} step={main_start_step} total_step={total_step}")
        elif temporal_ckpt_path.exists() and temporal_model is not None and opt_temporal is not None:
            tckpt = load_torch_state(str(temporal_ckpt_path), device)
            temporal_model.load_state_dict(tckpt["temporal"])
            opt_temporal.load_state_dict(tckpt["temporal_optim"])
            temporal_start_epoch = int(tckpt.get("epoch", 0))
            temporal_start_step = int(tckpt.get("step_in_epoch", 0))
            temporal_total_step = int(tckpt.get("temporal_step", 0))
            print(
                f"resumed_temporal_ckpt epoch={temporal_start_epoch} "
                f"step={temporal_start_step} temporal_step={temporal_total_step}"
            )
        elif temporal_path.exists() and temporal_model is not None:
            temporal_model.load_state_dict(load_torch_state(str(temporal_path), device))
            temporal_done = True
            print("reused_completed_temporal_stage", str(temporal_path))

    # ---- Stage 1: temporal next-step model ----
    if temporal_model is not None and opt_temporal is not None and not temporal_done:
        for epoch in range(temporal_start_epoch, int(config.get("temporal_epochs", 1))):
            skip_until = temporal_start_step if epoch == temporal_start_epoch else 0
            for step, batch in enumerate(make_batches(return_file_id=False)):
                if step < skip_until:
                    continue
                x_cont, _ = batch
                x_cont = x_cont.to(device)
                x_cont_model = x_cont[:, :, model_idx]
                cond_cont = x_cont[:, :, cond_idx] if temporal_cond_dim > 0 else None
                _, pred_next = temporal_model.forward_teacher(x_cont_model, cond_cont=cond_cont)
                target_next = x_cont_model[:, 1:, :]
                if trend_mask is not None:
                    mask = trend_mask.to(dtype=pred_next.dtype, device=pred_next.device)
                    mse = (pred_next - target_next) ** 2
                    # Mean over the (batch, time) positions of the unmasked features only.
                    temporal_loss = (mse * mask).sum() / torch.clamp(
                        mask.sum() * mse.size(0) * mse.size(1), min=1.0
                    )
                else:
                    temporal_loss = F.mse_loss(pred_next, target_next)
                opt_temporal.zero_grad()
                temporal_loss.backward()
                if float(config.get("grad_clip", 0.0)) > 0:
                    torch.nn.utils.clip_grad_norm_(temporal_model.parameters(), float(config["grad_clip"]))
                opt_temporal.step()
                temporal_total_step += 1
                if step % int(config["log_every"]) == 0:
                    print("temporal_epoch", epoch, "step", step, "loss", float(temporal_loss))
                if temporal_total_step % int(config["ckpt_every"]) == 0:
                    atomic_torch_save(
                        _temporal_checkpoint(
                            temporal_model,
                            opt_temporal,
                            config,
                            epoch=epoch,
                            step_in_epoch=step + 1,
                            temporal_total_step=temporal_total_step,
                        ),
                        temporal_ckpt_path,
                    )
            atomic_torch_save(
                _temporal_checkpoint(
                    temporal_model,
                    opt_temporal,
                    config,
                    epoch=epoch + 1,
                    step_in_epoch=0,
                    temporal_total_step=temporal_total_step,
                ),
                temporal_ckpt_path,
            )
        atomic_torch_save(temporal_model.state_dict(), temporal_path)
        temporal_done = True

    if bool(args.temporal_only):
        return

    # ---- Stage 2: hybrid diffusion on the residual ----
    mask_tokens = torch.tensor(vocab_sizes, device=device)
    cont_target = str(config.get("cont_target", "eps"))
    use_inv_std_weighting = config.get("cont_loss_weighting") == "inv_std"
    cont_loss_weights = None  # built lazily on the first step, in the model's dtype
    # The frozen temporal model must be conditioned exactly as in stage 1.
    temporal_uses_cond = temporal_cond_dim > 0

    for epoch in range(main_start_epoch, int(config["epochs"])):
        skip_until = main_start_step if epoch == main_start_epoch else 0
        for step, batch in enumerate(make_batches(return_file_id=use_condition)):
            if step < skip_until:
                continue
            if use_condition:
                x_cont, x_disc, cond = batch
                cond = cond.to(device)
            else:
                x_cont, x_disc = batch
                cond = None
            x_cont = x_cont.to(device)
            x_disc = x_disc.to(device)

            x_cont_model = x_cont[:, :, model_idx]
            cond_cont = x_cont[:, :, cond_idx] if cond_idx else None

            trend = None
            if temporal_model is not None:
                temporal_model.eval()
                with torch.no_grad():
                    trend, _ = temporal_model.forward_teacher(
                        x_cont_model,
                        cond_cont=cond_cont if temporal_uses_cond else None,
                    )
            if trend_mask is not None and trend is not None:
                trend = trend * trend_mask.to(dtype=trend.dtype, device=trend.device)
            x_cont_resid = x_cont_model if trend is None else x_cont_model - trend

            bsz = x_cont.size(0)
            t = torch.randint(0, timesteps, (bsz,), device=device)
            a_bar_t = alphas_cumprod[t].view(-1, 1, 1)

            x_cont_t, noise = q_sample_continuous(x_cont_resid, t, alphas_cumprod)
            x_disc_t, mask = q_sample_discrete(
                x_disc,
                t,
                mask_tokens,
                timesteps,
                mask_scale=float(config.get("disc_mask_scale", 1.0)),
            )

            eps_pred, logits = model(x_cont_t, x_disc_t, t, cond, cond_cont=cond_cont)

            # Continuous loss: per-element squared error, optionally reweighted
            # per feature (inv_std) and per sample (SNR), then averaged.
            if cont_target == "x0":
                x0_target = x_cont_resid
                clamp_x0 = float(config.get("cont_clamp_x0", 0.0))
                if clamp_x0 > 0:
                    x0_target = torch.clamp(x0_target, -clamp_x0, clamp_x0)
                loss_base = (eps_pred - x0_target) ** 2
            else:
                loss_base = (eps_pred - noise) ** 2

            if use_inv_std_weighting:
                if cont_loss_weights is None:
                    cont_eps = float(config.get("cont_loss_eps", 1e-6))
                    cont_loss_weights = torch.tensor(
                        [1.0 / (float(raw_std[c]) ** 2 + cont_eps) for c in model_cont_cols],
                        device=device,
                        dtype=eps_pred.dtype,
                    ).view(1, 1, -1)
                loss_base = loss_base * cont_loss_weights

            if bool(config.get("snr_weighted_loss", False)):
                snr = a_bar_t / torch.clamp(1.0 - a_bar_t, min=1e-8)
                gamma = float(config.get("snr_gamma", 1.0))
                # Weight each sample by its own timestep's SNR before reducing.
                # The old code multiplied the already-averaged scalar by the
                # batch-mean weight, which merely rescaled the loss.
                loss_base = loss_base * (snr / (snr + gamma))
            loss_cont = loss_base.mean()

            # Discrete loss: cross-entropy on masked positions, averaged over
            # the columns that had at least one masked position.
            loss_disc = 0.0
            loss_disc_count = 0
            for i, logit in enumerate(logits):
                col_mask = mask[:, :, i]
                if col_mask.any():
                    loss_disc = loss_disc + F.cross_entropy(
                        logit[col_mask],
                        x_disc[:, :, i][col_mask],
                    )
                    loss_disc_count += 1
            if loss_disc_count > 0:
                loss_disc = loss_disc / loss_disc_count

            lam = float(config["lambda"])
            loss = lam * loss_cont + (1 - lam) * loss_disc

            q_weight = float(config.get("quantile_loss_weight", 0.0))
            stat_weight = float(config.get("residual_stat_weight", 0.0))
            x_gen = None
            if q_weight > 0 or stat_weight > 0:
                x_gen = _predicted_x0(eps_pred, x_cont_t, a_bar_t, cont_target)

            if q_weight > 0:
                q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
                q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
                # reshape (not view): safe even if the tensor is non-contiguous.
                x_real_flat = x_cont_resid.reshape(-1, x_cont_resid.size(-1))
                x_gen_flat = x_gen.reshape(-1, x_gen.size(-1))
                q_real = torch.quantile(x_real_flat, q_tensor, dim=0)
                q_gen = torch.quantile(x_gen_flat, q_tensor, dim=0)
                quantile_loss = torch.mean(torch.abs(q_gen - q_real))
                loss = loss + q_weight * quantile_loss

            if stat_weight > 0:
                mean_real = x_cont_resid.mean(dim=(0, 1))
                mean_gen = x_gen.mean(dim=(0, 1))
                std_real = x_cont_resid.std(dim=(0, 1))
                std_gen = x_gen.std(dim=(0, 1))
                stat_loss = F.mse_loss(mean_gen, mean_real) + F.mse_loss(std_gen, std_real)
                loss = loss + stat_weight * stat_loss

            opt.zero_grad()
            loss.backward()
            if float(config.get("grad_clip", 0.0)) > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
            opt.step()
            if ema is not None:
                ema.update(model)

            if step % int(config["log_every"]) == 0:
                print("epoch", epoch, "step", step, "loss", float(loss))
                with open(log_path, "a", encoding="utf-8") as f:
                    f.write(
                        "%d,%d,%.6f,%.6f,%.6f\n"
                        % (epoch, step, float(loss), float(loss_cont), float(loss_disc))
                    )

            total_step += 1
            if total_step % int(config["ckpt_every"]) == 0:
                atomic_torch_save(
                    _main_checkpoint(
                        model,
                        opt,
                        ema,
                        temporal_model,
                        opt_temporal,
                        config,
                        total_step=total_step,
                        epoch=epoch,
                        step_in_epoch=step + 1,
                        temporal_done=temporal_done,
                        temporal_total_step=temporal_total_step,
                    ),
                    main_ckpt_path,
                )

        # End of epoch: the next resume starts cleanly at the following epoch.
        atomic_torch_save(
            _main_checkpoint(
                model,
                opt,
                ema,
                temporal_model,
                opt_temporal,
                config,
                total_step=total_step,
                epoch=epoch + 1,
                step_in_epoch=0,
                temporal_done=temporal_done,
                temporal_total_step=temporal_total_step,
            ),
            main_ckpt_path,
        )

    atomic_torch_save(model.state_dict(), model_path)
    if ema is not None:
        atomic_torch_save(ema.state_dict(), ema_path)
    if temporal_model is not None:
        atomic_torch_save(temporal_model.state_dict(), temporal_path)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
|