Add resumable submission pipeline
This commit is contained in:
81
example/config_submission_full.json
Normal file
81
example/config_submission_full.json
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
{
|
||||||
|
"data_path": "../../dataset/hai/hai-21.03/train1.csv.gz",
|
||||||
|
"data_glob": "../../dataset/hai/hai-21.03/train*.csv.gz",
|
||||||
|
"split_path": "./feature_split.json",
|
||||||
|
"stats_path": "./results/cont_stats.json",
|
||||||
|
"vocab_path": "./results/disc_vocab.json",
|
||||||
|
"out_dir": "./results",
|
||||||
|
"device": "auto",
|
||||||
|
"timesteps": 600,
|
||||||
|
"batch_size": 12,
|
||||||
|
"seq_len": 96,
|
||||||
|
"epochs": 10,
|
||||||
|
"max_batches": 4000,
|
||||||
|
"lambda": 0.7,
|
||||||
|
"lr": 0.0005,
|
||||||
|
"seed": 1337,
|
||||||
|
"log_every": 10,
|
||||||
|
"ckpt_every": 50,
|
||||||
|
"ema_decay": 0.999,
|
||||||
|
"use_ema": true,
|
||||||
|
"clip_k": 5.0,
|
||||||
|
"grad_clip": 1.0,
|
||||||
|
"use_condition": true,
|
||||||
|
"condition_type": "file_id",
|
||||||
|
"cond_dim": 32,
|
||||||
|
"use_tanh_eps": false,
|
||||||
|
"eps_scale": 1.0,
|
||||||
|
"model_time_dim": 128,
|
||||||
|
"model_hidden_dim": 512,
|
||||||
|
"model_num_layers": 2,
|
||||||
|
"model_dropout": 0.1,
|
||||||
|
"model_ff_mult": 2,
|
||||||
|
"model_pos_dim": 64,
|
||||||
|
"model_use_pos_embed": true,
|
||||||
|
"backbone_type": "transformer",
|
||||||
|
"transformer_num_layers": 3,
|
||||||
|
"transformer_nhead": 4,
|
||||||
|
"transformer_ff_dim": 512,
|
||||||
|
"transformer_dropout": 0.1,
|
||||||
|
"disc_mask_scale": 0.9,
|
||||||
|
"cont_loss_weighting": "inv_std",
|
||||||
|
"cont_loss_eps": 1e-6,
|
||||||
|
"cont_target": "x0",
|
||||||
|
"cont_clamp_x0": 5.0,
|
||||||
|
"use_quantile_transform": true,
|
||||||
|
"quantile_bins": 1001,
|
||||||
|
"cont_bound_mode": "none",
|
||||||
|
"cont_bound_strength": 2.0,
|
||||||
|
"cont_post_calibrate": true,
|
||||||
|
"cont_post_scale": {},
|
||||||
|
"full_stats": true,
|
||||||
|
"type1_features": ["P1_B4002", "P2_MSD", "P4_HT_LD", "P1_B2004", "P1_B3004", "P1_B4022", "P1_B3005"],
|
||||||
|
"type2_features": ["P1_B4005"],
|
||||||
|
"type3_features": ["P1_PCV02Z", "P1_PCV01Z", "P1_PCV01D", "P1_FCV02Z", "P1_FCV03D", "P1_FCV03Z", "P1_LCV01D", "P1_LCV01Z"],
|
||||||
|
"type4_features": ["P1_PIT02", "P2_SIT02", "P1_FT03"],
|
||||||
|
"type5_features": ["P1_FT03Z", "P1_FT02Z"],
|
||||||
|
"type6_features": ["P4_HT_PO", "P2_24Vdc", "P2_HILout"],
|
||||||
|
"routing_type1_features": ["P1_B4022"],
|
||||||
|
"routing_type5_features": [],
|
||||||
|
"shuffle_buffer": 256,
|
||||||
|
"use_temporal_stage1": true,
|
||||||
|
"temporal_backbone": "transformer",
|
||||||
|
"temporal_hidden_dim": 256,
|
||||||
|
"temporal_num_layers": 1,
|
||||||
|
"temporal_dropout": 0.0,
|
||||||
|
"temporal_pos_dim": 64,
|
||||||
|
"temporal_use_pos_embed": true,
|
||||||
|
"temporal_transformer_num_layers": 2,
|
||||||
|
"temporal_transformer_nhead": 4,
|
||||||
|
"temporal_transformer_ff_dim": 256,
|
||||||
|
"temporal_transformer_dropout": 0.1,
|
||||||
|
"temporal_epochs": 3,
|
||||||
|
"temporal_lr": 0.001,
|
||||||
|
"quantile_loss_weight": 0.2,
|
||||||
|
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
|
||||||
|
"snr_weighted_loss": true,
|
||||||
|
"snr_gamma": 1.0,
|
||||||
|
"residual_stat_weight": 0.05,
|
||||||
|
"sample_batch_size": 4,
|
||||||
|
"sample_seq_len": 96
|
||||||
|
}
|
||||||
440
example/export_samples_resume.py
Normal file
440
example/export_samples_resume.py
Normal file
@@ -0,0 +1,440 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Sample from a trained hybrid diffusion model with routing-aware export fixes."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import gzip
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from data_utils import inverse_quantile_transform, load_split, normalize_cont, quantile_calibrate_to_real
|
||||||
|
from export_samples import build_inverse_vocab, load_stats, load_torch_state, read_header
|
||||||
|
from hybrid_diffusion import (
|
||||||
|
HybridDiffusionModel,
|
||||||
|
TemporalGRUGenerator,
|
||||||
|
TemporalTransformerGenerator,
|
||||||
|
cosine_beta_schedule,
|
||||||
|
)
|
||||||
|
from platform_utils import resolve_device, resolve_path
|
||||||
|
from submission_type_utils import (
|
||||||
|
denormalize_cont_tensor,
|
||||||
|
resolve_routing_features,
|
||||||
|
resolve_taxonomy_features,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Sample and export HAI feature sequences with routing-aware fixes.")
|
||||||
|
base_dir = Path(__file__).resolve().parent
|
||||||
|
repo_dir = base_dir.parent.parent
|
||||||
|
parser.add_argument("--data-path", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train1.csv.gz"))
|
||||||
|
parser.add_argument("--data-glob", default=str(repo_dir / "dataset" / "hai" / "hai-21.03" / "train*.csv.gz"))
|
||||||
|
parser.add_argument("--split-path", default=str(base_dir / "feature_split.json"))
|
||||||
|
parser.add_argument("--stats-path", default=str(base_dir / "results" / "cont_stats.json"))
|
||||||
|
parser.add_argument("--vocab-path", default=str(base_dir / "results" / "disc_vocab.json"))
|
||||||
|
parser.add_argument("--model-path", default=str(base_dir / "results" / "model.pt"))
|
||||||
|
parser.add_argument("--out", default=str(base_dir / "results" / "generated.csv"))
|
||||||
|
parser.add_argument("--timesteps", type=int, default=200)
|
||||||
|
parser.add_argument("--seq-len", type=int, default=64)
|
||||||
|
parser.add_argument("--batch-size", type=int, default=2)
|
||||||
|
parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
|
||||||
|
parser.add_argument("--include-time", action="store_true", help="Include time column as a simple index")
|
||||||
|
parser.add_argument("--clip-k", type=float, default=5.0, help="Clip continuous values to mean±k*std")
|
||||||
|
parser.add_argument("--use-ema", action="store_true", help="Use EMA weights if available")
|
||||||
|
parser.add_argument("--config", default=None, help="Optional config_used.json to infer conditioning")
|
||||||
|
parser.add_argument("--condition-id", type=int, default=-1, help="Condition file id (0..N-1), -1=random")
|
||||||
|
parser.add_argument("--include-condition", action="store_true", help="Include condition id column in CSV")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
base_dir = Path(__file__).resolve().parent
|
||||||
|
args.data_path = str(resolve_path(base_dir, args.data_path))
|
||||||
|
args.data_glob = str(resolve_path(base_dir, args.data_glob)) if args.data_glob else ""
|
||||||
|
args.split_path = str(resolve_path(base_dir, args.split_path))
|
||||||
|
args.stats_path = str(resolve_path(base_dir, args.stats_path))
|
||||||
|
args.vocab_path = str(resolve_path(base_dir, args.vocab_path))
|
||||||
|
args.model_path = str(resolve_path(base_dir, args.model_path))
|
||||||
|
args.out = str(resolve_path(base_dir, args.out))
|
||||||
|
|
||||||
|
if not os.path.exists(args.model_path):
|
||||||
|
raise SystemExit(f"missing model file: {args.model_path}")
|
||||||
|
|
||||||
|
data_path = args.data_path
|
||||||
|
if args.data_glob:
|
||||||
|
base = Path(args.data_glob).parent
|
||||||
|
pat = Path(args.data_glob).name
|
||||||
|
matches = sorted(base.glob(pat))
|
||||||
|
if matches:
|
||||||
|
data_path = str(matches[0])
|
||||||
|
|
||||||
|
split = load_split(args.split_path)
|
||||||
|
time_col = split.get("time_column", "time")
|
||||||
|
cont_cols = [c for c in split["continuous"] if c != time_col]
|
||||||
|
disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
|
||||||
|
|
||||||
|
stats = load_stats(args.stats_path)
|
||||||
|
mean = stats["mean"]
|
||||||
|
std = stats["std"]
|
||||||
|
vmin = stats.get("min", {})
|
||||||
|
vmax = stats.get("max", {})
|
||||||
|
int_like = stats.get("int_like", {})
|
||||||
|
max_decimals = stats.get("max_decimals", {})
|
||||||
|
transforms = stats.get("transform", {})
|
||||||
|
quantile_probs = stats.get("quantile_probs")
|
||||||
|
quantile_values = stats.get("quantile_values")
|
||||||
|
quantile_raw_values = stats.get("quantile_raw_values")
|
||||||
|
|
||||||
|
vocab_json = json.load(open(args.vocab_path, "r", encoding="utf-8"))
|
||||||
|
vocab = vocab_json["vocab"]
|
||||||
|
top_token = vocab_json.get("top_token", {})
|
||||||
|
inv_vocab = build_inverse_vocab(vocab)
|
||||||
|
vocab_sizes = [len(vocab[c]) for c in disc_cols]
|
||||||
|
|
||||||
|
device = resolve_device(args.device)
|
||||||
|
cfg = {}
|
||||||
|
use_condition = False
|
||||||
|
cond_vocab_size = 0
|
||||||
|
if args.config:
|
||||||
|
args.config = str(resolve_path(base_dir, args.config))
|
||||||
|
if args.config and os.path.exists(args.config):
|
||||||
|
with open(args.config, "r", encoding="utf-8") as f:
|
||||||
|
cfg = json.load(f)
|
||||||
|
use_condition = bool(cfg.get("use_condition")) and cfg.get("condition_type") == "file_id"
|
||||||
|
if use_condition:
|
||||||
|
cfg_base = Path(args.config).resolve().parent
|
||||||
|
cfg_glob = cfg.get("data_glob", args.data_glob)
|
||||||
|
cfg_glob = str(resolve_path(cfg_base, cfg_glob))
|
||||||
|
base = Path(cfg_glob).parent
|
||||||
|
pat = Path(cfg_glob).name
|
||||||
|
cond_vocab_size = len(sorted(base.glob(pat)))
|
||||||
|
if cond_vocab_size <= 0:
|
||||||
|
raise SystemExit(f"use_condition enabled but no files matched data_glob: {cfg_glob}")
|
||||||
|
|
||||||
|
cont_target = str(cfg.get("cont_target", "eps"))
|
||||||
|
cont_clamp_x0 = float(cfg.get("cont_clamp_x0", 0.0))
|
||||||
|
use_quantile = bool(cfg.get("use_quantile_transform", False))
|
||||||
|
cont_bound_mode = str(cfg.get("cont_bound_mode", "clamp"))
|
||||||
|
cont_bound_strength = float(cfg.get("cont_bound_strength", 1.0))
|
||||||
|
cont_post_scale = cfg.get("cont_post_scale", {}) if isinstance(cfg.get("cont_post_scale", {}), dict) else {}
|
||||||
|
cont_post_calibrate = bool(cfg.get("cont_post_calibrate", False))
|
||||||
|
|
||||||
|
route_type1_cols = resolve_routing_features(cfg, cont_cols, "type1_features")
|
||||||
|
route_type5_cols = resolve_routing_features(cfg, cont_cols, "type5_features")
|
||||||
|
type4_cols = resolve_taxonomy_features(cfg, cont_cols, "type4_features")
|
||||||
|
model_cont_cols = [c for c in cont_cols if c not in route_type1_cols and c not in route_type5_cols]
|
||||||
|
|
||||||
|
use_temporal_stage1 = bool(cfg.get("use_temporal_stage1", False))
|
||||||
|
temporal_use_type1_cond = bool(cfg.get("temporal_use_type1_cond", False))
|
||||||
|
temporal_focus_type4 = bool(cfg.get("temporal_focus_type4", False))
|
||||||
|
temporal_exclude_type4 = bool(cfg.get("temporal_exclude_type4", False))
|
||||||
|
temporal_backbone = str(cfg.get("temporal_backbone", "gru"))
|
||||||
|
temporal_hidden_dim = int(cfg.get("temporal_hidden_dim", 256))
|
||||||
|
temporal_num_layers = int(cfg.get("temporal_num_layers", 1))
|
||||||
|
temporal_dropout = float(cfg.get("temporal_dropout", 0.0))
|
||||||
|
temporal_pos_dim = int(cfg.get("temporal_pos_dim", 64))
|
||||||
|
temporal_use_pos_embed = bool(cfg.get("temporal_use_pos_embed", True))
|
||||||
|
temporal_transformer_num_layers = int(cfg.get("temporal_transformer_num_layers", 2))
|
||||||
|
temporal_transformer_nhead = int(cfg.get("temporal_transformer_nhead", 4))
|
||||||
|
temporal_transformer_ff_dim = int(cfg.get("temporal_transformer_ff_dim", 512))
|
||||||
|
temporal_transformer_dropout = float(cfg.get("temporal_transformer_dropout", 0.1))
|
||||||
|
backbone_type = str(cfg.get("backbone_type", "gru"))
|
||||||
|
transformer_num_layers = int(cfg.get("transformer_num_layers", 2))
|
||||||
|
transformer_nhead = int(cfg.get("transformer_nhead", 4))
|
||||||
|
transformer_ff_dim = int(cfg.get("transformer_ff_dim", 512))
|
||||||
|
transformer_dropout = float(cfg.get("transformer_dropout", 0.1))
|
||||||
|
|
||||||
|
model = HybridDiffusionModel(
|
||||||
|
cont_dim=len(model_cont_cols),
|
||||||
|
disc_vocab_sizes=vocab_sizes,
|
||||||
|
time_dim=int(cfg.get("model_time_dim", 64)),
|
||||||
|
hidden_dim=int(cfg.get("model_hidden_dim", 256)),
|
||||||
|
num_layers=int(cfg.get("model_num_layers", 1)),
|
||||||
|
dropout=float(cfg.get("model_dropout", 0.0)),
|
||||||
|
ff_mult=int(cfg.get("model_ff_mult", 2)),
|
||||||
|
pos_dim=int(cfg.get("model_pos_dim", 64)),
|
||||||
|
use_pos_embed=bool(cfg.get("model_use_pos_embed", True)),
|
||||||
|
backbone_type=backbone_type,
|
||||||
|
transformer_num_layers=transformer_num_layers,
|
||||||
|
transformer_nhead=transformer_nhead,
|
||||||
|
transformer_ff_dim=transformer_ff_dim,
|
||||||
|
transformer_dropout=transformer_dropout,
|
||||||
|
cond_cont_dim=len(route_type1_cols),
|
||||||
|
cond_vocab_size=cond_vocab_size if use_condition else 0,
|
||||||
|
cond_dim=int(cfg.get("cond_dim", 32)),
|
||||||
|
use_tanh_eps=bool(cfg.get("use_tanh_eps", False)),
|
||||||
|
eps_scale=float(cfg.get("eps_scale", 1.0)),
|
||||||
|
).to(device)
|
||||||
|
if args.use_ema and os.path.exists(args.model_path.replace("model.pt", "model_ema.pt")):
|
||||||
|
ema_path = args.model_path.replace("model.pt", "model_ema.pt")
|
||||||
|
model.load_state_dict(load_torch_state(ema_path, device))
|
||||||
|
else:
|
||||||
|
model.load_state_dict(load_torch_state(args.model_path, device))
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
temporal_model = None
|
||||||
|
if use_temporal_stage1:
|
||||||
|
temporal_path = Path(args.model_path).with_name("temporal.pt")
|
||||||
|
if not temporal_path.exists():
|
||||||
|
raise SystemExit(f"missing temporal model file: {temporal_path}")
|
||||||
|
temporal_state = load_torch_state(str(temporal_path), device)
|
||||||
|
temporal_cond_dim = len(route_type1_cols) if (temporal_use_type1_cond and route_type1_cols) else 0
|
||||||
|
if isinstance(temporal_state, dict):
|
||||||
|
if "in_proj.weight" in temporal_state:
|
||||||
|
try:
|
||||||
|
temporal_cond_dim = max(0, int(temporal_state["in_proj.weight"].shape[1]) - len(model_cont_cols))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
elif "gru.weight_ih_l0" in temporal_state:
|
||||||
|
try:
|
||||||
|
temporal_cond_dim = max(0, int(temporal_state["gru.weight_ih_l0"].shape[1]) - len(model_cont_cols))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if temporal_backbone == "transformer":
|
||||||
|
temporal_model = TemporalTransformerGenerator(
|
||||||
|
input_dim=len(model_cont_cols),
|
||||||
|
hidden_dim=temporal_hidden_dim,
|
||||||
|
num_layers=temporal_transformer_num_layers,
|
||||||
|
nhead=temporal_transformer_nhead,
|
||||||
|
ff_dim=temporal_transformer_ff_dim,
|
||||||
|
dropout=temporal_transformer_dropout,
|
||||||
|
pos_dim=temporal_pos_dim,
|
||||||
|
use_pos_embed=temporal_use_pos_embed,
|
||||||
|
cond_dim=temporal_cond_dim,
|
||||||
|
).to(device)
|
||||||
|
else:
|
||||||
|
temporal_model = TemporalGRUGenerator(
|
||||||
|
input_dim=len(model_cont_cols),
|
||||||
|
hidden_dim=temporal_hidden_dim,
|
||||||
|
num_layers=temporal_num_layers,
|
||||||
|
dropout=temporal_dropout,
|
||||||
|
cond_dim=temporal_cond_dim,
|
||||||
|
).to(device)
|
||||||
|
temporal_model.load_state_dict(temporal_state)
|
||||||
|
temporal_model.eval()
|
||||||
|
|
||||||
|
betas = cosine_beta_schedule(args.timesteps).to(device)
|
||||||
|
alphas = 1.0 - betas
|
||||||
|
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||||
|
|
||||||
|
x_cont = torch.randn(args.batch_size, args.seq_len, len(model_cont_cols), device=device)
|
||||||
|
x_disc = torch.full((args.batch_size, args.seq_len, len(disc_cols)), 0, device=device, dtype=torch.long)
|
||||||
|
mask_tokens = torch.tensor(vocab_sizes, device=device)
|
||||||
|
for i in range(len(disc_cols)):
|
||||||
|
x_disc[:, :, i] = mask_tokens[i]
|
||||||
|
|
||||||
|
cond = None
|
||||||
|
if use_condition:
|
||||||
|
if cond_vocab_size <= 0:
|
||||||
|
raise SystemExit("use_condition enabled but no files matched data_glob")
|
||||||
|
if args.condition_id < 0:
|
||||||
|
cond_id = torch.randint(0, cond_vocab_size, (args.batch_size,), device=device)
|
||||||
|
else:
|
||||||
|
cond_id = torch.full((args.batch_size,), int(args.condition_id), device=device, dtype=torch.long)
|
||||||
|
cond = cond_id
|
||||||
|
|
||||||
|
cond_cont = None
|
||||||
|
if route_type1_cols:
|
||||||
|
ref_glob = cfg.get("data_glob") or args.data_glob
|
||||||
|
if ref_glob:
|
||||||
|
ref_glob = str(resolve_path(Path(args.config).parent, ref_glob)) if args.config else ref_glob
|
||||||
|
base = Path(ref_glob).parent
|
||||||
|
pat = Path(ref_glob).name
|
||||||
|
refs = sorted(base.glob(pat))
|
||||||
|
if refs:
|
||||||
|
ref_path = refs[0]
|
||||||
|
ref_rows = []
|
||||||
|
with gzip.open(ref_path, "rt", newline="") as fh:
|
||||||
|
reader = csv.DictReader(fh)
|
||||||
|
for row in reader:
|
||||||
|
ref_rows.append(row)
|
||||||
|
if len(ref_rows) >= args.seq_len:
|
||||||
|
seq = ref_rows[: args.seq_len]
|
||||||
|
cond_cont = torch.zeros(args.batch_size, args.seq_len, len(route_type1_cols), device=device)
|
||||||
|
for t, row in enumerate(seq):
|
||||||
|
for i, c in enumerate(route_type1_cols):
|
||||||
|
cond_cont[:, t, i] = float(row[c])
|
||||||
|
cond_cont = normalize_cont(
|
||||||
|
cond_cont,
|
||||||
|
route_type1_cols,
|
||||||
|
mean,
|
||||||
|
std,
|
||||||
|
transforms=transforms,
|
||||||
|
quantile_probs=quantile_probs,
|
||||||
|
quantile_values=quantile_values,
|
||||||
|
use_quantile=use_quantile,
|
||||||
|
)
|
||||||
|
|
||||||
|
trend = None
|
||||||
|
if temporal_model is not None:
|
||||||
|
trend = temporal_model.generate(args.batch_size, args.seq_len, device, cond_cont=cond_cont)
|
||||||
|
if temporal_focus_type4 and type4_cols:
|
||||||
|
type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols]
|
||||||
|
if type4_model_idx:
|
||||||
|
trend_mask = torch.zeros(1, 1, len(model_cont_cols), device=device, dtype=trend.dtype)
|
||||||
|
trend_mask[:, :, type4_model_idx] = 1.0
|
||||||
|
trend = trend * trend_mask
|
||||||
|
elif temporal_exclude_type4 and type4_cols:
|
||||||
|
type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols]
|
||||||
|
if type4_model_idx:
|
||||||
|
trend_mask = torch.ones(1, 1, len(model_cont_cols), device=device, dtype=trend.dtype)
|
||||||
|
trend_mask[:, :, type4_model_idx] = 0.0
|
||||||
|
trend = trend * trend_mask
|
||||||
|
|
||||||
|
for t in reversed(range(args.timesteps)):
|
||||||
|
t_batch = torch.full((args.batch_size,), t, device=device, dtype=torch.long)
|
||||||
|
eps_pred, logits = model(x_cont, x_disc, t_batch, cond, cond_cont=cond_cont)
|
||||||
|
|
||||||
|
a_t = alphas[t]
|
||||||
|
a_bar_t = alphas_cumprod[t]
|
||||||
|
|
||||||
|
if cont_target == "x0":
|
||||||
|
x0_pred = eps_pred
|
||||||
|
if cont_clamp_x0 > 0:
|
||||||
|
x0_pred = torch.clamp(x0_pred, -cont_clamp_x0, cont_clamp_x0)
|
||||||
|
eps_pred = (x_cont - torch.sqrt(a_bar_t) * x0_pred) / torch.sqrt(1.0 - a_bar_t)
|
||||||
|
coef1 = 1.0 / torch.sqrt(a_t)
|
||||||
|
coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
|
||||||
|
mean_x = coef1 * (x_cont - coef2 * eps_pred)
|
||||||
|
if t > 0:
|
||||||
|
noise = torch.randn_like(x_cont)
|
||||||
|
x_cont = mean_x + torch.sqrt(betas[t]) * noise
|
||||||
|
else:
|
||||||
|
x_cont = mean_x
|
||||||
|
if args.clip_k > 0:
|
||||||
|
x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
|
||||||
|
|
||||||
|
for i, logit in enumerate(logits):
|
||||||
|
if t == 0:
|
||||||
|
probs = F.softmax(logit, dim=-1)
|
||||||
|
x_disc[:, :, i] = torch.argmax(probs, dim=-1)
|
||||||
|
else:
|
||||||
|
mask = x_disc[:, :, i] == mask_tokens[i]
|
||||||
|
if mask.any():
|
||||||
|
probs = F.softmax(logit, dim=-1)
|
||||||
|
sampled = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(
|
||||||
|
args.batch_size, args.seq_len
|
||||||
|
)
|
||||||
|
x_disc[:, :, i][mask] = sampled[mask]
|
||||||
|
|
||||||
|
if trend is not None:
|
||||||
|
x_cont = x_cont + trend
|
||||||
|
x_cont = x_cont.cpu()
|
||||||
|
x_disc = x_disc.cpu()
|
||||||
|
|
||||||
|
if args.clip_k > 0:
|
||||||
|
x_cont = torch.clamp(x_cont, -args.clip_k, args.clip_k)
|
||||||
|
|
||||||
|
if use_quantile:
|
||||||
|
q_vals = {c: quantile_values[c] for c in model_cont_cols}
|
||||||
|
x_cont = inverse_quantile_transform(x_cont, model_cont_cols, quantile_probs, q_vals)
|
||||||
|
else:
|
||||||
|
mean_vec = torch.tensor([mean[c] for c in model_cont_cols], dtype=x_cont.dtype)
|
||||||
|
std_vec = torch.tensor([std[c] for c in model_cont_cols], dtype=x_cont.dtype)
|
||||||
|
x_cont = x_cont * std_vec + mean_vec
|
||||||
|
for i, c in enumerate(model_cont_cols):
|
||||||
|
if transforms.get(c) == "log1p":
|
||||||
|
x_cont[:, :, i] = torch.expm1(x_cont[:, :, i])
|
||||||
|
if cont_post_calibrate and quantile_raw_values and quantile_probs:
|
||||||
|
q_raw = {c: quantile_raw_values[c] for c in model_cont_cols}
|
||||||
|
x_cont = quantile_calibrate_to_real(x_cont, model_cont_cols, quantile_probs, q_raw)
|
||||||
|
if vmin and vmax:
|
||||||
|
for i, c in enumerate(model_cont_cols):
|
||||||
|
lo = vmin.get(c, None)
|
||||||
|
hi = vmax.get(c, None)
|
||||||
|
if lo is None or hi is None:
|
||||||
|
continue
|
||||||
|
lo = float(lo)
|
||||||
|
hi = float(hi)
|
||||||
|
if cont_bound_mode == "none":
|
||||||
|
continue
|
||||||
|
if cont_bound_mode == "sigmoid":
|
||||||
|
x_cont[:, :, i] = lo + (hi - lo) * torch.sigmoid(x_cont[:, :, i])
|
||||||
|
elif cont_bound_mode == "soft_tanh":
|
||||||
|
mid = 0.5 * (lo + hi)
|
||||||
|
half = 0.5 * (hi - lo)
|
||||||
|
denom = cont_bound_strength if cont_bound_strength > 0 else 1.0
|
||||||
|
x_cont[:, :, i] = mid + half * torch.tanh(x_cont[:, :, i] / denom)
|
||||||
|
else:
|
||||||
|
x_cont[:, :, i] = torch.clamp(x_cont[:, :, i], lo, hi)
|
||||||
|
|
||||||
|
if cont_post_scale:
|
||||||
|
for i, c in enumerate(model_cont_cols):
|
||||||
|
if c in cont_post_scale:
|
||||||
|
try:
|
||||||
|
scale = float(cont_post_scale[c])
|
||||||
|
except Exception:
|
||||||
|
scale = 1.0
|
||||||
|
x_cont[:, :, i] = x_cont[:, :, i] * scale
|
||||||
|
|
||||||
|
full_cont = torch.zeros(args.batch_size, args.seq_len, len(cont_cols), dtype=x_cont.dtype)
|
||||||
|
for i, c in enumerate(model_cont_cols):
|
||||||
|
full_idx = cont_cols.index(c)
|
||||||
|
full_cont[:, :, full_idx] = x_cont[:, :, i]
|
||||||
|
if cond_cont is not None and route_type1_cols:
|
||||||
|
cond_denorm = denormalize_cont_tensor(
|
||||||
|
cond_cont.cpu(),
|
||||||
|
route_type1_cols,
|
||||||
|
mean,
|
||||||
|
std,
|
||||||
|
transforms=transforms,
|
||||||
|
quantile_probs=quantile_probs,
|
||||||
|
quantile_values=quantile_values,
|
||||||
|
use_quantile=use_quantile,
|
||||||
|
)
|
||||||
|
for i, c in enumerate(route_type1_cols):
|
||||||
|
full_idx = cont_cols.index(c)
|
||||||
|
full_cont[:, :, full_idx] = cond_denorm[:, :, i]
|
||||||
|
for c in route_type5_cols:
|
||||||
|
if c.endswith("Z"):
|
||||||
|
base = c[:-1]
|
||||||
|
if base in cont_cols:
|
||||||
|
bidx = cont_cols.index(base)
|
||||||
|
cidx = cont_cols.index(c)
|
||||||
|
full_cont[:, :, cidx] = full_cont[:, :, bidx]
|
||||||
|
|
||||||
|
header = read_header(data_path)
|
||||||
|
out_cols = [c for c in header if c != time_col or args.include_time]
|
||||||
|
if args.include_condition and use_condition:
|
||||||
|
out_cols = ["__cond_file_id"] + out_cols
|
||||||
|
|
||||||
|
os.makedirs(os.path.dirname(args.out), exist_ok=True)
|
||||||
|
with open(args.out, "w", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=out_cols)
|
||||||
|
writer.writeheader()
|
||||||
|
|
||||||
|
row_index = 0
|
||||||
|
for b in range(args.batch_size):
|
||||||
|
for t in range(args.seq_len):
|
||||||
|
row = {}
|
||||||
|
if args.include_condition and use_condition:
|
||||||
|
row["__cond_file_id"] = str(int(cond[b].item())) if cond is not None else "-1"
|
||||||
|
if args.include_time and time_col in header:
|
||||||
|
row[time_col] = str(row_index)
|
||||||
|
for i, c in enumerate(cont_cols):
|
||||||
|
val = float(full_cont[b, t, i])
|
||||||
|
if int_like.get(c, False):
|
||||||
|
row[c] = str(int(round(val)))
|
||||||
|
else:
|
||||||
|
dec = int(max_decimals.get(c, 6))
|
||||||
|
fmt = ("%%.%df" % dec) if dec > 0 else "%.0f"
|
||||||
|
row[c] = fmt % val
|
||||||
|
for i, c in enumerate(disc_cols):
|
||||||
|
tok_idx = int(x_disc[b, t, i])
|
||||||
|
tok = inv_vocab[c][tok_idx] if tok_idx < len(inv_vocab[c]) else "<UNK>"
|
||||||
|
if tok == "<UNK>" and c in top_token:
|
||||||
|
tok = top_token[c]
|
||||||
|
row[c] = tok
|
||||||
|
writer.writerow(row)
|
||||||
|
row_index += 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
19
example/run_submission_full.sh
Normal file
19
example/run_submission_full.sh
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
RUN_DIR="${RUN_DIR:-$SCRIPT_DIR/results/submission_full}"
|
||||||
|
LOG_DIR="$RUN_DIR/logs"
|
||||||
|
mkdir -p "$LOG_DIR"
|
||||||
|
|
||||||
|
STAMP="$(date '+%Y%m%d-%H%M%S')"
|
||||||
|
LOG_FILE="$LOG_DIR/pipeline-$STAMP.log"
|
||||||
|
|
||||||
|
echo "[run_submission_full] run_dir=$RUN_DIR"
|
||||||
|
echo "[run_submission_full] log_file=$LOG_FILE"
|
||||||
|
|
||||||
|
python "$SCRIPT_DIR/run_submission_resume.py" \
|
||||||
|
--config "$SCRIPT_DIR/config_submission_full.json" \
|
||||||
|
--device "${DEVICE:-cuda}" \
|
||||||
|
--run-dir "$RUN_DIR" \
|
||||||
|
"$@" 2>&1 | tee -a "$LOG_FILE"
|
||||||
399
example/run_submission_resume.py
Normal file
399
example/run_submission_resume.py
Normal file
@@ -0,0 +1,399 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""One-command full pipeline runner with safe resume and stage skipping."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from platform_utils import is_windows, safe_path
|
||||||
|
|
||||||
|
|
||||||
|
def run(cmd: List[str]) -> None:
|
||||||
|
print("running:", " ".join(cmd))
|
||||||
|
cmd = [safe_path(arg) for arg in cmd]
|
||||||
|
if is_windows():
|
||||||
|
subprocess.run(cmd, check=True, shell=False)
|
||||||
|
else:
|
||||||
|
subprocess.run(cmd, check=True)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
base_dir = Path(__file__).resolve().parent
|
||||||
|
parser = argparse.ArgumentParser(description="Run prepare -> train -> export -> eval with resume-aware staging.")
|
||||||
|
parser.add_argument("--config", default=str(base_dir / "config_submission_full.json"))
|
||||||
|
parser.add_argument("--device", default="auto")
|
||||||
|
parser.add_argument("--run-dir", default=str(base_dir / "results" / "submission_full"))
|
||||||
|
parser.add_argument("--reference", default="")
|
||||||
|
parser.add_argument("--no-resume", action="store_true", help="Do not auto-skip completed stages or resume from ckpt.")
|
||||||
|
parser.add_argument("--skip-prepare", action="store_true")
|
||||||
|
parser.add_argument("--skip-train", action="store_true")
|
||||||
|
parser.add_argument("--skip-export", action="store_true")
|
||||||
|
parser.add_argument("--skip-eval", action="store_true")
|
||||||
|
parser.add_argument("--skip-comprehensive-eval", action="store_true")
|
||||||
|
parser.add_argument("--skip-postprocess", action="store_true")
|
||||||
|
parser.add_argument("--skip-post-eval", action="store_true")
|
||||||
|
parser.add_argument("--skip-diagnostics", action="store_true")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def load_state(path: Path) -> Dict[str, str]:
|
||||||
|
if not path.exists():
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
return json.loads(path.read_text(encoding="utf-8"))
|
||||||
|
except Exception:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def save_state(path: Path, state: Dict[str, str]) -> None:
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
path.write_text(json.dumps(state, indent=2, sort_keys=True), encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def stage_complete(state: Dict[str, str], stage: str, outputs: List[Path], resume: bool) -> bool:
|
||||||
|
if not resume:
|
||||||
|
return False
|
||||||
|
if outputs and all(p.exists() for p in outputs):
|
||||||
|
return True
|
||||||
|
return state.get(stage) == "done"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run the resumable submission pipeline: prepare -> train -> export ->
    eval -> postprocess -> diagnostics, skipping stages whose outputs exist.

    Stage progress is persisted to ``<run_dir>/pipeline_state.json`` and every
    launched command is appended to ``<run_dir>/run_commands.txt``.
    """
    args = parse_args()
    # Paths from the CLI are resolved relative to this script's directory.
    base_dir = Path(__file__).resolve().parent
    config_path = Path(args.config)
    if not config_path.is_absolute():
        config_path = (base_dir / config_path).resolve()
    run_dir = Path(args.run_dir)
    if not run_dir.is_absolute():
        run_dir = (base_dir / run_dir).resolve()
    run_dir.mkdir(parents=True, exist_ok=True)

    cfg = json.loads(config_path.read_text(encoding="utf-8"))
    cfg_base = config_path.parent

    def abs_cfg_like(value: str) -> str:
        # Make a config-relative path absolute. Glob patterns are joined but
        # NOT resolve()d, since resolve() on a nonexistent glob path could
        # normalize away characters the glob expansion needs.
        p = Path(value)
        if p.is_absolute():
            return str(p)
        if any(ch in value for ch in ["*", "?", "["]):
            return str(cfg_base / p)
        return str((cfg_base / p).resolve())

    # Reference data for evaluation: explicit CLI flag wins, then the config's
    # glob, then its single data path; empty string means "no reference".
    ref = args.reference or cfg.get("data_glob") or cfg.get("data_path") or ""
    if ref:
        ref = abs_cfg_like(str(ref))

    # Sampling knobs fall back to the training values when no dedicated
    # sample_* override is present in the config.
    timesteps = int(cfg.get("timesteps", 200))
    seq_len = int(cfg.get("sample_seq_len", cfg.get("seq_len", 64)))
    batch_size = int(cfg.get("sample_batch_size", cfg.get("batch_size", 2)))
    clip_k = float(cfg.get("clip_k", 5.0))
    split_path = abs_cfg_like(str(cfg.get("split_path", "./feature_split.json")))
    stats_path = abs_cfg_like(str(cfg.get("stats_path", "./results/cont_stats.json")))
    vocab_path = abs_cfg_like(str(cfg.get("vocab_path", "./results/disc_vocab.json")))
    data_path = abs_cfg_like(str(cfg.get("data_path", ""))) if cfg.get("data_path") else ""
    data_glob = abs_cfg_like(str(cfg.get("data_glob", ""))) if cfg.get("data_glob") else ""

    state_path = run_dir / "pipeline_state.json"
    state = load_state(state_path)
    resume = not args.no_resume
    # Training writes the fully-resolved config here; later stages prefer it
    # over the user-supplied config when it exists.
    cfg_for_steps = run_dir / "config_used.json"

    # Each stage is (name, expected output files, command argv). Stages are
    # appended in execution order and skipped individually via --skip-* flags.
    stage_defs = []
    if not args.skip_prepare:
        stage_defs.append(
            (
                "prepare",
                [Path(stats_path), Path(vocab_path)],
                [sys.executable, str(base_dir / "prepare_data.py"), "--config", str(config_path)],
            )
        )
    if not args.skip_train:
        train_cmd = [
            sys.executable,
            str(base_dir / "train_resume.py"),
            "--config",
            str(config_path),
            "--device",
            args.device,
            "--out-dir",
            str(run_dir),
            "--seed",
            str(int(cfg.get("seed", 1337))),
        ]
        if resume:
            train_cmd.append("--resume")
        stage_defs.append(("train", [run_dir / "model.pt"], train_cmd))
    if not args.skip_export:
        stage_defs.append(
            (
                "export",
                [run_dir / "generated.csv"],
                [
                    sys.executable,
                    str(base_dir / "export_samples_resume.py"),
                    "--include-time",
                    "--device",
                    args.device,
                    "--config",
                    str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    "--data-path",
                    str(data_path),
                    "--data-glob",
                    str(data_glob),
                    "--split-path",
                    str(split_path),
                    "--stats-path",
                    str(stats_path),
                    "--vocab-path",
                    str(vocab_path),
                    "--model-path",
                    str(run_dir / "model.pt"),
                    "--out",
                    str(run_dir / "generated.csv"),
                    "--timesteps",
                    str(timesteps),
                    "--seq-len",
                    str(seq_len),
                    "--batch-size",
                    str(batch_size),
                    "--clip-k",
                    str(clip_k),
                    "--use-ema",
                ],
            )
        )
    if not args.skip_eval:
        eval_cmd = [
            sys.executable,
            str(base_dir / "evaluate_generated.py"),
            "--generated",
            str(run_dir / "generated.csv"),
            "--split",
            str(split_path),
            "--stats",
            str(stats_path),
            "--vocab",
            str(vocab_path),
            "--out",
            str(run_dir / "eval.json"),
        ]
        if ref:
            eval_cmd += ["--reference", str(ref)]
        stage_defs.append(("eval", [run_dir / "eval.json"], eval_cmd))
    if not args.skip_comprehensive_eval:
        stage_defs.append(
            (
                "comprehensive_eval",
                [run_dir / "comprehensive_eval.json"],
                [
                    sys.executable,
                    str(base_dir / "evaluate_comprehensive.py"),
                    "--generated",
                    str(run_dir / "generated.csv"),
                    # NOTE(review): this passes the CONFIG path as --reference,
                    # while the eval/postprocess stages pass the data path
                    # (`ref`). Confirm evaluate_comprehensive.py really expects
                    # the config file here.
                    "--reference",
                    str(config_path),
                    "--config",
                    str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    "--split",
                    str(split_path),
                    "--stats",
                    str(stats_path),
                    "--vocab",
                    str(vocab_path),
                    "--out",
                    str(run_dir / "comprehensive_eval.json"),
                    "--device",
                    args.device,
                ],
            )
        )
    if not args.skip_postprocess:
        post_cmd = [
            sys.executable,
            str(base_dir / "postprocess_types.py"),
            "--generated",
            str(run_dir / "generated.csv"),
            "--config",
            str(cfg_for_steps if cfg_for_steps.exists() else config_path),
            "--out",
            str(run_dir / "generated_post.csv"),
            "--seed",
            str(int(cfg.get("seed", 1337))),
        ]
        if ref:
            post_cmd += ["--reference", str(ref)]
        stage_defs.append(("postprocess", [run_dir / "generated_post.csv"], post_cmd))
    if not args.skip_post_eval:
        # Same evaluator as the "eval" stage, but on the postprocessed CSV.
        post_eval_cmd = [
            sys.executable,
            str(base_dir / "evaluate_generated.py"),
            "--generated",
            str(run_dir / "generated_post.csv"),
            "--split",
            str(split_path),
            "--stats",
            str(stats_path),
            "--vocab",
            str(vocab_path),
            "--out",
            str(run_dir / "eval_post.json"),
        ]
        if ref:
            post_eval_cmd += ["--reference", str(ref)]
        stage_defs.append(("post_eval", [run_dir / "eval_post.json"], post_eval_cmd))
    if not args.skip_comprehensive_eval:
        stage_defs.append(
            (
                "comprehensive_post_eval",
                [run_dir / "comprehensive_eval_post.json"],
                [
                    sys.executable,
                    str(base_dir / "evaluate_comprehensive.py"),
                    "--generated",
                    str(run_dir / "generated_post.csv"),
                    # NOTE(review): config path as --reference again — see the
                    # comprehensive_eval stage above.
                    "--reference",
                    str(config_path),
                    "--config",
                    str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    "--split",
                    str(split_path),
                    "--stats",
                    str(stats_path),
                    "--vocab",
                    str(vocab_path),
                    "--out",
                    str(run_dir / "comprehensive_eval_post.json"),
                    "--device",
                    args.device,
                ],
            )
        )
    if not args.skip_diagnostics:
        # Post-hoc diagnostic reports; each is an independent small script.
        stage_defs.extend(
            [
                (
                    "filtered_metrics",
                    [run_dir / "filtered_metrics.json"],
                    [
                        sys.executable,
                        str(base_dir / "filtered_metrics.py"),
                        "--eval",
                        str(run_dir / "eval.json"),
                        "--out",
                        str(run_dir / "filtered_metrics.json"),
                    ],
                ),
                (
                    "ranked_ks",
                    [run_dir / "ranked_ks.csv"],
                    [
                        sys.executable,
                        str(base_dir / "ranked_ks.py"),
                        "--eval",
                        str(run_dir / "eval.json"),
                        "--out",
                        str(run_dir / "ranked_ks.csv"),
                    ],
                ),
                (
                    "program_stats",
                    [run_dir / "program_stats.json"],
                    [
                        sys.executable,
                        str(base_dir / "program_stats.py"),
                        "--generated",
                        str(run_dir / "generated.csv"),
                        "--reference",
                        str(config_path),
                        "--config",
                        str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    ],
                ),
                (
                    "controller_stats",
                    [run_dir / "controller_stats.json"],
                    [
                        sys.executable,
                        str(base_dir / "controller_stats.py"),
                        "--generated",
                        str(run_dir / "generated.csv"),
                        "--reference",
                        str(config_path),
                        "--config",
                        str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    ],
                ),
                (
                    "actuator_stats",
                    [run_dir / "actuator_stats.json"],
                    [
                        sys.executable,
                        str(base_dir / "actuator_stats.py"),
                        "--generated",
                        str(run_dir / "generated.csv"),
                        "--reference",
                        str(config_path),
                        "--config",
                        str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    ],
                ),
                (
                    "pv_stats",
                    [run_dir / "pv_stats.json"],
                    [
                        sys.executable,
                        str(base_dir / "pv_stats.py"),
                        "--generated",
                        str(run_dir / "generated.csv"),
                        "--reference",
                        str(config_path),
                        "--config",
                        str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    ],
                ),
                (
                    "aux_stats",
                    [run_dir / "aux_stats.json"],
                    [
                        sys.executable,
                        str(base_dir / "aux_stats.py"),
                        "--generated",
                        str(run_dir / "generated.csv"),
                        "--reference",
                        str(config_path),
                        "--config",
                        str(cfg_for_steps if cfg_for_steps.exists() else config_path),
                    ],
                ),
            ]
        )

    command_log = run_dir / "run_commands.txt"
    if not command_log.exists():
        command_log.write_text("", encoding="utf-8")

    # Execute stages in order. State is written before and after each stage so
    # an interrupted run leaves a "running" marker rather than a stale "done".
    for stage, outputs, cmd in stage_defs:
        if stage_complete(state, stage, outputs, resume):
            print(f"skip_stage {stage}: outputs already present")
            state[stage] = "done"
            save_state(state_path, state)
            continue
        state[stage] = "running"
        save_state(state_path, state)
        with command_log.open("a", encoding="utf-8") as fh:
            fh.write(stage + ": " + " ".join(cmd) + "\n")
        run(cmd)
        state[stage] = "done"
        save_state(state_path, state)

    print(f"pipeline_complete run_dir={run_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the pipeline when invoked directly.
if __name__ == "__main__":
    main()
|
||||||
53
example/submission_type_utils.py
Normal file
53
example/submission_type_utils.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Helpers for keeping type taxonomy and routing policy separate."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from data_utils import inverse_quantile_transform
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_taxonomy_features(config: Dict, cont_cols: List[str], base_key: str) -> List[str]:
    """Return the features configured under *base_key*, restricted to known
    continuous columns. Configured order is preserved; unknown names drop out.
    """
    configured = config.get(base_key, []) or []
    known = set(cont_cols)
    return [name for name in configured if name in known]
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_routing_features(config: Dict, cont_cols: List[str], base_key: str) -> List[str]:
    """Like the taxonomy variant, but a ``routing_<base_key>`` config entry
    overrides *base_key* when present. Result keeps configured order and is
    restricted to known continuous columns.
    """
    override_key = f"routing_{base_key}"
    configured = config.get(override_key, config.get(base_key, [])) or []
    known = set(cont_cols)
    return [name for name in configured if name in known]
|
||||||
|
|
||||||
|
|
||||||
|
def denormalize_cont_tensor(
    x: torch.Tensor,
    cont_cols: List[str],
    mean: Dict[str, float],
    std: Dict[str, float],
    transforms: Optional[Dict[str, str]] = None,
    quantile_probs: Optional[List[float]] = None,
    quantile_values: Optional[Dict[str, List[float]]] = None,
    use_quantile: bool = False,
) -> torch.Tensor:
    """Map normalized continuous features back to raw scale.

    Args:
        x: tensor with continuous features along the LAST dimension, in the
           same order as *cont_cols*. Any rank is supported.
        cont_cols: feature names; empty means "nothing to denormalize".
        mean, std: per-feature normalization stats (used when not quantile).
        transforms: optional per-feature forward-transform names; features
           marked "log1p" are inverted with expm1 after denormalization.
        quantile_probs, quantile_values: quantile stats, required when
           *use_quantile* is true.
        use_quantile: invert a quantile transform instead of mean/std.

    Returns:
        A new tensor (input is never modified in place).

    Raises:
        ValueError: if *x* is None, or quantile stats are missing while
            *use_quantile* is set.
    """
    if x is None:
        raise ValueError("x must not be None")
    if not cont_cols:
        return x.clone()

    out = x.clone()
    if use_quantile:
        if not quantile_probs or not quantile_values:
            raise ValueError("use_quantile=True but quantile stats are missing")
        q_vals = {c: quantile_values[c] for c in cont_cols}
        out = inverse_quantile_transform(out, cont_cols, quantile_probs, q_vals)
    else:
        # Broadcasting over the last dim handles any input rank.
        mean_vec = torch.tensor([mean[c] for c in cont_cols], dtype=out.dtype, device=out.device)
        std_vec = torch.tensor([std[c] for c in cont_cols], dtype=out.dtype, device=out.device)
        out = out * std_vec + mean_vec

    if transforms:
        for i, c in enumerate(cont_cols):
            if transforms.get(c) == "log1p":
                # Ellipsis indexing generalizes the original `out[:, :, i]`,
                # which hard-coded a rank-3 (batch, seq, feat) layout; this
                # now works for tensors of any rank with features last.
                out[..., i] = torch.expm1(out[..., i])
    return out
|
||||||
534
example/train_resume.py
Normal file
534
example/train_resume.py
Normal file
@@ -0,0 +1,534 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Train hybrid diffusion with checkpoint resume and selective type-aware routing."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from data_utils import load_split, windowed_batches
|
||||||
|
from hybrid_diffusion import (
|
||||||
|
HybridDiffusionModel,
|
||||||
|
TemporalGRUGenerator,
|
||||||
|
TemporalTransformerGenerator,
|
||||||
|
cosine_beta_schedule,
|
||||||
|
q_sample_continuous,
|
||||||
|
q_sample_discrete,
|
||||||
|
)
|
||||||
|
from platform_utils import resolve_device, resolve_path, safe_path
|
||||||
|
from submission_type_utils import resolve_routing_features, resolve_taxonomy_features
|
||||||
|
from train import DEFAULTS, EMA, load_json, resolve_config_paths, set_seed
|
||||||
|
|
||||||
|
BASE_DIR = Path(__file__).resolve().parent
|
||||||
|
|
||||||
|
|
||||||
|
def load_torch_state(path: str, device: str):
    """Load a torch checkpoint onto *device*, preferring weights-only mode.

    Falls back to a plain load on torch versions that predate the
    ``weights_only`` keyword (which raises TypeError there).
    """
    kwargs = {"map_location": device}
    try:
        return torch.load(path, weights_only=True, **kwargs)
    except TypeError:
        # Older torch: no weights_only kwarg.
        return torch.load(path, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def atomic_torch_save(obj, path: Path) -> None:
    """Write a checkpoint via a sibling temp file plus atomic rename, so a
    crash mid-write never leaves a truncated file at *path*."""
    path.parent.mkdir(parents=True, exist_ok=True)
    staging = path.with_suffix(path.suffix + ".tmp")
    torch.save(obj, str(staging))
    os.replace(str(staging), str(path))
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
    """CLI for resumable training. Every flag has a default, so the script
    can run bare; flags override the JSON config where they overlap."""
    cli = argparse.ArgumentParser(description="Train hybrid diffusion on HAI with resume support.")
    cli.add_argument("--config", default=None, help="Path to JSON config.")
    cli.add_argument("--device", default="auto", help="cpu, cuda, or auto")
    cli.add_argument("--out-dir", default=None, help="Override output directory")
    cli.add_argument("--seed", type=int, default=None, help="Override random seed")
    cli.add_argument("--temporal-only", action="store_true", help="Only train temporal stage-1 and exit.")
    cli.add_argument("--resume", action="store_true", help="Resume from checkpoint in out-dir if present.")
    cli.add_argument("--resume-ckpt", default=None, help="Optional explicit model checkpoint path.")
    return cli.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def build_temporal_model(config: Dict, model_cont_cols, temporal_cond_dim: int, device: str):
    """Instantiate the stage-1 temporal generator selected by the
    ``temporal_backbone`` config key ("transformer", otherwise GRU), move it
    to *device*, and return it."""
    backbone = str(config.get("temporal_backbone", "gru"))
    common = {
        "input_dim": len(model_cont_cols),
        "hidden_dim": int(config.get("temporal_hidden_dim", 256)),
        "cond_dim": temporal_cond_dim,
    }
    if backbone == "transformer":
        generator = TemporalTransformerGenerator(
            num_layers=int(config.get("temporal_transformer_num_layers", 2)),
            nhead=int(config.get("temporal_transformer_nhead", 4)),
            ff_dim=int(config.get("temporal_transformer_ff_dim", 512)),
            dropout=float(config.get("temporal_transformer_dropout", 0.1)),
            pos_dim=int(config.get("temporal_pos_dim", 64)),
            use_pos_embed=bool(config.get("temporal_use_pos_embed", True)),
            **common,
        )
    else:
        generator = TemporalGRUGenerator(
            num_layers=int(config.get("temporal_num_layers", 1)),
            dropout=float(config.get("temporal_dropout", 0.0)),
            **common,
        )
    return generator.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def init_or_append_log(log_path: Path, resume: bool) -> None:
    """Create the CSV training log with its header row, unless resuming onto
    an existing log (in which case rows keep accumulating)."""
    if resume and log_path.exists():
        return
    log_path.write_text("epoch,step,loss,loss_cont,loss_disc\n", encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
if args.config:
|
||||||
|
print("using_config", str(Path(args.config).resolve()))
|
||||||
|
config = dict(DEFAULTS)
|
||||||
|
if args.config:
|
||||||
|
cfg_path = Path(args.config).resolve()
|
||||||
|
config.update(load_json(str(cfg_path)))
|
||||||
|
config = resolve_config_paths(config, cfg_path.parent)
|
||||||
|
else:
|
||||||
|
config = resolve_config_paths(config, BASE_DIR)
|
||||||
|
|
||||||
|
if args.device != "auto":
|
||||||
|
config["device"] = args.device
|
||||||
|
if args.out_dir:
|
||||||
|
out_dir = Path(args.out_dir)
|
||||||
|
if not out_dir.is_absolute():
|
||||||
|
base = Path(args.config).resolve().parent if args.config else BASE_DIR
|
||||||
|
out_dir = resolve_path(base, out_dir)
|
||||||
|
config["out_dir"] = str(out_dir)
|
||||||
|
if args.seed is not None:
|
||||||
|
config["seed"] = int(args.seed)
|
||||||
|
if bool(args.temporal_only):
|
||||||
|
config["use_temporal_stage1"] = True
|
||||||
|
config["epochs"] = 0
|
||||||
|
|
||||||
|
set_seed(int(config["seed"]))
|
||||||
|
|
||||||
|
split = load_split(config["split_path"])
|
||||||
|
time_col = split.get("time_column", "time")
|
||||||
|
cont_cols = [c for c in split["continuous"] if c != time_col]
|
||||||
|
disc_cols = [c for c in split["discrete"] if not c.startswith("attack") and c != time_col]
|
||||||
|
|
||||||
|
type1_cols = resolve_routing_features(config, cont_cols, "type1_features")
|
||||||
|
type5_cols = resolve_routing_features(config, cont_cols, "type5_features")
|
||||||
|
type4_cols = resolve_taxonomy_features(config, cont_cols, "type4_features")
|
||||||
|
model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
|
||||||
|
if not model_cont_cols:
|
||||||
|
raise SystemExit("model_cont_cols is empty; check routing_type1_features/routing_type5_features")
|
||||||
|
|
||||||
|
stats = load_json(config["stats_path"])
|
||||||
|
mean = stats["mean"]
|
||||||
|
std = stats["std"]
|
||||||
|
transforms = stats.get("transform", {})
|
||||||
|
raw_std = stats.get("raw_std", std)
|
||||||
|
quantile_probs = stats.get("quantile_probs")
|
||||||
|
quantile_values = stats.get("quantile_values")
|
||||||
|
use_quantile = bool(config.get("use_quantile_transform", False))
|
||||||
|
|
||||||
|
vocab = load_json(config["vocab_path"])["vocab"]
|
||||||
|
vocab_sizes = [len(vocab[c]) for c in disc_cols]
|
||||||
|
|
||||||
|
data_paths = None
|
||||||
|
if "data_glob" in config and config["data_glob"]:
|
||||||
|
data_paths = sorted(Path(config["data_glob"]).parent.glob(Path(config["data_glob"]).name))
|
||||||
|
if data_paths:
|
||||||
|
data_paths = [safe_path(p) for p in data_paths]
|
||||||
|
if not data_paths:
|
||||||
|
data_paths = [safe_path(config["data_path"])]
|
||||||
|
|
||||||
|
use_condition = bool(config.get("use_condition")) and config.get("condition_type") == "file_id"
|
||||||
|
cond_vocab_size = len(data_paths) if use_condition else 0
|
||||||
|
|
||||||
|
device = resolve_device(str(config["device"]))
|
||||||
|
print("device", device)
|
||||||
|
model = HybridDiffusionModel(
|
||||||
|
cont_dim=len(model_cont_cols),
|
||||||
|
disc_vocab_sizes=vocab_sizes,
|
||||||
|
time_dim=int(config.get("model_time_dim", 64)),
|
||||||
|
hidden_dim=int(config.get("model_hidden_dim", 256)),
|
||||||
|
num_layers=int(config.get("model_num_layers", 1)),
|
||||||
|
dropout=float(config.get("model_dropout", 0.0)),
|
||||||
|
ff_mult=int(config.get("model_ff_mult", 2)),
|
||||||
|
pos_dim=int(config.get("model_pos_dim", 64)),
|
||||||
|
use_pos_embed=bool(config.get("model_use_pos_embed", True)),
|
||||||
|
backbone_type=str(config.get("backbone_type", "gru")),
|
||||||
|
transformer_num_layers=int(config.get("transformer_num_layers", 4)),
|
||||||
|
transformer_nhead=int(config.get("transformer_nhead", 8)),
|
||||||
|
transformer_ff_dim=int(config.get("transformer_ff_dim", 2048)),
|
||||||
|
transformer_dropout=float(config.get("transformer_dropout", 0.1)),
|
||||||
|
cond_cont_dim=len(type1_cols),
|
||||||
|
cond_vocab_size=cond_vocab_size,
|
||||||
|
cond_dim=int(config.get("cond_dim", 32)),
|
||||||
|
use_tanh_eps=bool(config.get("use_tanh_eps", False)),
|
||||||
|
eps_scale=float(config.get("eps_scale", 1.0)),
|
||||||
|
).to(device)
|
||||||
|
opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
|
||||||
|
|
||||||
|
temporal_model = None
|
||||||
|
opt_temporal = None
|
||||||
|
temporal_use_type1_cond = bool(config.get("temporal_use_type1_cond", False))
|
||||||
|
temporal_cond_dim = len(type1_cols) if (temporal_use_type1_cond and type1_cols) else 0
|
||||||
|
temporal_focus_type4 = bool(config.get("temporal_focus_type4", False))
|
||||||
|
temporal_exclude_type4 = bool(config.get("temporal_exclude_type4", False))
|
||||||
|
type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols]
|
||||||
|
trend_mask = None
|
||||||
|
if temporal_focus_type4 and type4_model_idx:
|
||||||
|
trend_mask = torch.zeros(1, 1, len(model_cont_cols), device=device)
|
||||||
|
trend_mask[:, :, type4_model_idx] = 1.0
|
||||||
|
elif temporal_exclude_type4 and type4_model_idx:
|
||||||
|
trend_mask = torch.ones(1, 1, len(model_cont_cols), device=device)
|
||||||
|
trend_mask[:, :, type4_model_idx] = 0.0
|
||||||
|
if bool(config.get("use_temporal_stage1", False)):
|
||||||
|
temporal_model = build_temporal_model(config, model_cont_cols, temporal_cond_dim, device)
|
||||||
|
opt_temporal = torch.optim.Adam(
|
||||||
|
temporal_model.parameters(),
|
||||||
|
lr=float(config.get("temporal_lr", config["lr"])),
|
||||||
|
)
|
||||||
|
ema = EMA(model, float(config["ema_decay"])) if config.get("use_ema") else None
|
||||||
|
|
||||||
|
betas = cosine_beta_schedule(int(config["timesteps"])).to(device)
|
||||||
|
alphas = 1.0 - betas
|
||||||
|
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||||
|
|
||||||
|
os.makedirs(config["out_dir"], exist_ok=True)
|
||||||
|
out_dir = Path(safe_path(config["out_dir"]))
|
||||||
|
log_path = out_dir / "train_log.csv"
|
||||||
|
init_or_append_log(log_path, args.resume)
|
||||||
|
|
||||||
|
with open(out_dir / "config_used.json", "w", encoding="utf-8") as f:
|
||||||
|
json.dump(config, f, indent=2)
|
||||||
|
|
||||||
|
main_ckpt_path = Path(args.resume_ckpt).resolve() if args.resume_ckpt else (out_dir / "model_ckpt.pt")
|
||||||
|
temporal_ckpt_path = out_dir / "temporal_ckpt.pt"
|
||||||
|
model_path = out_dir / "model.pt"
|
||||||
|
ema_path = out_dir / "model_ema.pt"
|
||||||
|
temporal_path = out_dir / "temporal.pt"
|
||||||
|
|
||||||
|
temporal_start_epoch = 0
|
||||||
|
temporal_start_step = 0
|
||||||
|
temporal_total_step = 0
|
||||||
|
main_start_epoch = 0
|
||||||
|
main_start_step = 0
|
||||||
|
total_step = 0
|
||||||
|
temporal_done = temporal_model is None
|
||||||
|
|
||||||
|
if args.resume:
|
||||||
|
if main_ckpt_path.exists():
|
||||||
|
ckpt = load_torch_state(str(main_ckpt_path), device)
|
||||||
|
model.load_state_dict(ckpt["model"])
|
||||||
|
opt.load_state_dict(ckpt["optim"])
|
||||||
|
total_step = int(ckpt.get("step", 0))
|
||||||
|
main_start_epoch = int(ckpt.get("epoch", 0))
|
||||||
|
main_start_step = int(ckpt.get("step_in_epoch", 0))
|
||||||
|
temporal_done = bool(ckpt.get("temporal_done", temporal_done))
|
||||||
|
temporal_total_step = int(ckpt.get("temporal_step", 0))
|
||||||
|
if ema is not None and ckpt.get("ema") is not None:
|
||||||
|
ema.shadow = ckpt["ema"]
|
||||||
|
if temporal_model is not None and ckpt.get("temporal") is not None:
|
||||||
|
temporal_model.load_state_dict(ckpt["temporal"])
|
||||||
|
if opt_temporal is not None and ckpt.get("temporal_optim") is not None:
|
||||||
|
opt_temporal.load_state_dict(ckpt["temporal_optim"])
|
||||||
|
print(f"resumed_main_ckpt epoch={main_start_epoch} step={main_start_step} total_step={total_step}")
|
||||||
|
elif temporal_ckpt_path.exists() and temporal_model is not None and opt_temporal is not None:
|
||||||
|
tckpt = load_torch_state(str(temporal_ckpt_path), device)
|
||||||
|
temporal_model.load_state_dict(tckpt["temporal"])
|
||||||
|
opt_temporal.load_state_dict(tckpt["temporal_optim"])
|
||||||
|
temporal_start_epoch = int(tckpt.get("epoch", 0))
|
||||||
|
temporal_start_step = int(tckpt.get("step_in_epoch", 0))
|
||||||
|
temporal_total_step = int(tckpt.get("temporal_step", 0))
|
||||||
|
print(
|
||||||
|
f"resumed_temporal_ckpt epoch={temporal_start_epoch} "
|
||||||
|
f"step={temporal_start_step} temporal_step={temporal_total_step}"
|
||||||
|
)
|
||||||
|
elif temporal_path.exists() and temporal_model is not None:
|
||||||
|
temporal_model.load_state_dict(load_torch_state(str(temporal_path), device))
|
||||||
|
temporal_done = True
|
||||||
|
print("reused_completed_temporal_stage", str(temporal_path))
|
||||||
|
|
||||||
|
if temporal_model is not None and opt_temporal is not None and not temporal_done:
|
||||||
|
for epoch in range(temporal_start_epoch, int(config.get("temporal_epochs", 1))):
|
||||||
|
skip_until = temporal_start_step if epoch == temporal_start_epoch else 0
|
||||||
|
for step, batch in enumerate(
|
||||||
|
windowed_batches(
|
||||||
|
data_paths,
|
||||||
|
cont_cols,
|
||||||
|
disc_cols,
|
||||||
|
vocab,
|
||||||
|
mean,
|
||||||
|
std,
|
||||||
|
batch_size=int(config["batch_size"]),
|
||||||
|
seq_len=int(config["seq_len"]),
|
||||||
|
max_batches=int(config["max_batches"]),
|
||||||
|
return_file_id=False,
|
||||||
|
transforms=transforms,
|
||||||
|
quantile_probs=quantile_probs,
|
||||||
|
quantile_values=quantile_values,
|
||||||
|
use_quantile=use_quantile,
|
||||||
|
shuffle_buffer=int(config.get("shuffle_buffer", 0)),
|
||||||
|
)
|
||||||
|
):
|
||||||
|
if step < skip_until:
|
||||||
|
continue
|
||||||
|
x_cont, _ = batch
|
||||||
|
x_cont = x_cont.to(device)
|
||||||
|
model_idx = [cont_cols.index(c) for c in model_cont_cols]
|
||||||
|
x_cont_model = x_cont[:, :, model_idx]
|
||||||
|
cond_cont = None
|
||||||
|
if temporal_cond_dim > 0:
|
||||||
|
cond_idx = [cont_cols.index(c) for c in type1_cols]
|
||||||
|
cond_cont = x_cont[:, :, cond_idx]
|
||||||
|
_, pred_next = temporal_model.forward_teacher(x_cont_model, cond_cont=cond_cont)
|
||||||
|
target_next = x_cont_model[:, 1:, :]
|
||||||
|
if trend_mask is not None:
|
||||||
|
mask = trend_mask.to(dtype=pred_next.dtype, device=pred_next.device)
|
||||||
|
mse = (pred_next - target_next) ** 2
|
||||||
|
temporal_loss = (mse * mask).sum() / torch.clamp(mask.sum() * mse.size(0) * mse.size(1), min=1.0)
|
||||||
|
else:
|
||||||
|
temporal_loss = F.mse_loss(pred_next, target_next)
|
||||||
|
opt_temporal.zero_grad()
|
||||||
|
temporal_loss.backward()
|
||||||
|
if float(config.get("grad_clip", 0.0)) > 0:
|
||||||
|
torch.nn.utils.clip_grad_norm_(temporal_model.parameters(), float(config["grad_clip"]))
|
||||||
|
opt_temporal.step()
|
||||||
|
temporal_total_step += 1
|
||||||
|
if step % int(config["log_every"]) == 0:
|
||||||
|
print("temporal_epoch", epoch, "step", step, "loss", float(temporal_loss))
|
||||||
|
if temporal_total_step % int(config["ckpt_every"]) == 0:
|
||||||
|
atomic_torch_save(
|
||||||
|
{
|
||||||
|
"temporal": temporal_model.state_dict(),
|
||||||
|
"temporal_optim": opt_temporal.state_dict(),
|
||||||
|
"epoch": epoch,
|
||||||
|
"step_in_epoch": step + 1,
|
||||||
|
"temporal_step": temporal_total_step,
|
||||||
|
"config": config,
|
||||||
|
},
|
||||||
|
temporal_ckpt_path,
|
||||||
|
)
|
||||||
|
temporal_start_step = 0
|
||||||
|
atomic_torch_save(
|
||||||
|
{
|
||||||
|
"temporal": temporal_model.state_dict(),
|
||||||
|
"temporal_optim": opt_temporal.state_dict(),
|
||||||
|
"epoch": epoch + 1,
|
||||||
|
"step_in_epoch": 0,
|
||||||
|
"temporal_step": temporal_total_step,
|
||||||
|
"config": config,
|
||||||
|
},
|
||||||
|
temporal_ckpt_path,
|
||||||
|
)
|
||||||
|
atomic_torch_save(temporal_model.state_dict(), temporal_path)
|
||||||
|
temporal_done = True
|
||||||
|
|
||||||
|
if bool(args.temporal_only):
|
||||||
|
return
|
||||||
|
|
||||||
|
for epoch in range(main_start_epoch, int(config["epochs"])):
|
||||||
|
skip_until = main_start_step if epoch == main_start_epoch else 0
|
||||||
|
for step, batch in enumerate(
|
||||||
|
windowed_batches(
|
||||||
|
data_paths,
|
||||||
|
cont_cols,
|
||||||
|
disc_cols,
|
||||||
|
vocab,
|
||||||
|
mean,
|
||||||
|
std,
|
||||||
|
batch_size=int(config["batch_size"]),
|
||||||
|
seq_len=int(config["seq_len"]),
|
||||||
|
max_batches=int(config["max_batches"]),
|
||||||
|
return_file_id=use_condition,
|
||||||
|
transforms=transforms,
|
||||||
|
quantile_probs=quantile_probs,
|
||||||
|
quantile_values=quantile_values,
|
||||||
|
use_quantile=use_quantile,
|
||||||
|
shuffle_buffer=int(config.get("shuffle_buffer", 0)),
|
||||||
|
)
|
||||||
|
):
|
||||||
|
if step < skip_until:
|
||||||
|
continue
|
||||||
|
if use_condition:
|
||||||
|
x_cont, x_disc, cond = batch
|
||||||
|
cond = cond.to(device)
|
||||||
|
else:
|
||||||
|
x_cont, x_disc = batch
|
||||||
|
cond = None
|
||||||
|
x_cont = x_cont.to(device)
|
||||||
|
x_disc = x_disc.to(device)
|
||||||
|
|
||||||
|
model_idx = [cont_cols.index(c) for c in model_cont_cols]
|
||||||
|
cond_idx = [cont_cols.index(c) for c in type1_cols] if type1_cols else []
|
||||||
|
x_cont_model = x_cont[:, :, model_idx]
|
||||||
|
cond_cont = x_cont[:, :, cond_idx] if cond_idx else None
|
||||||
|
|
||||||
|
trend = None
|
||||||
|
if temporal_model is not None:
|
||||||
|
temporal_model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
trend, _ = temporal_model.forward_teacher(x_cont_model, cond_cont=cond_cont)
|
||||||
|
if trend_mask is not None and trend is not None:
|
||||||
|
trend = trend * trend_mask.to(dtype=trend.dtype, device=trend.device)
|
||||||
|
x_cont_resid = x_cont_model if trend is None else x_cont_model - trend
|
||||||
|
|
||||||
|
bsz = x_cont.size(0)
|
||||||
|
t = torch.randint(0, int(config["timesteps"]), (bsz,), device=device)
|
||||||
|
|
||||||
|
x_cont_t, noise = q_sample_continuous(x_cont_resid, t, alphas_cumprod)
|
||||||
|
|
||||||
|
mask_tokens = torch.tensor(vocab_sizes, device=device)
|
||||||
|
x_disc_t, mask = q_sample_discrete(
|
||||||
|
x_disc,
|
||||||
|
t,
|
||||||
|
mask_tokens,
|
||||||
|
int(config["timesteps"]),
|
||||||
|
mask_scale=float(config.get("disc_mask_scale", 1.0)),
|
||||||
|
)
|
||||||
|
|
||||||
|
eps_pred, logits = model(x_cont_t, x_disc_t, t, cond, cond_cont=cond_cont)
|
||||||
|
|
||||||
|
cont_target = str(config.get("cont_target", "eps"))
|
||||||
|
if cont_target == "x0":
|
||||||
|
x0_target = x_cont_resid
|
||||||
|
if float(config.get("cont_clamp_x0", 0.0)) > 0:
|
||||||
|
x0_target = torch.clamp(
|
||||||
|
x0_target,
|
||||||
|
-float(config["cont_clamp_x0"]),
|
||||||
|
float(config["cont_clamp_x0"]),
|
||||||
|
)
|
||||||
|
loss_base = (eps_pred - x0_target) ** 2
|
||||||
|
else:
|
||||||
|
loss_base = (eps_pred - noise) ** 2
|
||||||
|
|
||||||
|
if config.get("cont_loss_weighting") == "inv_std":
|
||||||
|
weights = torch.tensor(
|
||||||
|
[1.0 / (float(raw_std[c]) ** 2 + float(config.get("cont_loss_eps", 1e-6))) for c in model_cont_cols],
|
||||||
|
device=device,
|
||||||
|
dtype=eps_pred.dtype,
|
||||||
|
).view(1, 1, -1)
|
||||||
|
loss_cont = (loss_base * weights).mean()
|
||||||
|
else:
|
||||||
|
loss_cont = loss_base.mean()
|
||||||
|
|
||||||
|
if bool(config.get("snr_weighted_loss", False)):
|
||||||
|
a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
|
||||||
|
snr = a_bar_t / torch.clamp(1.0 - a_bar_t, min=1e-8)
|
||||||
|
gamma = float(config.get("snr_gamma", 1.0))
|
||||||
|
snr_weight = snr / (snr + gamma)
|
||||||
|
loss_cont = (loss_cont * snr_weight.mean()).mean()
|
||||||
|
loss_disc = 0.0
|
||||||
|
loss_disc_count = 0
|
||||||
|
for i, logit in enumerate(logits):
|
||||||
|
if mask[:, :, i].any():
|
||||||
|
loss_disc = loss_disc + F.cross_entropy(
|
||||||
|
logit[mask[:, :, i]],
|
||||||
|
x_disc[:, :, i][mask[:, :, i]],
|
||||||
|
)
|
||||||
|
loss_disc_count += 1
|
||||||
|
if loss_disc_count > 0:
|
||||||
|
loss_disc = loss_disc / loss_disc_count
|
||||||
|
|
||||||
|
lam = float(config["lambda"])
|
||||||
|
loss = lam * loss_cont + (1 - lam) * loss_disc
|
||||||
|
|
||||||
|
q_weight = float(config.get("quantile_loss_weight", 0.0))
|
||||||
|
if q_weight > 0:
|
||||||
|
q_points = config.get("quantile_points", [0.05, 0.25, 0.5, 0.75, 0.95])
|
||||||
|
q_tensor = torch.tensor(q_points, device=device, dtype=x_cont.dtype)
|
||||||
|
a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
|
||||||
|
x_real = x_cont_resid
|
||||||
|
if cont_target == "x0":
|
||||||
|
x_gen = eps_pred
|
||||||
|
else:
|
||||||
|
x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
|
||||||
|
x_real = x_real.view(-1, x_real.size(-1))
|
||||||
|
x_gen = x_gen.view(-1, x_gen.size(-1))
|
||||||
|
q_real = torch.quantile(x_real, q_tensor, dim=0)
|
||||||
|
q_gen = torch.quantile(x_gen, q_tensor, dim=0)
|
||||||
|
quantile_loss = torch.mean(torch.abs(q_gen - q_real))
|
||||||
|
loss = loss + q_weight * quantile_loss
|
||||||
|
|
||||||
|
stat_weight = float(config.get("residual_stat_weight", 0.0))
|
||||||
|
if stat_weight > 0:
|
||||||
|
a_bar_t = alphas_cumprod[t].view(-1, 1, 1)
|
||||||
|
if cont_target == "x0":
|
||||||
|
x_gen = eps_pred
|
||||||
|
else:
|
||||||
|
x_gen = (x_cont_t - torch.sqrt(1.0 - a_bar_t) * eps_pred) / torch.sqrt(a_bar_t)
|
||||||
|
x_real = x_cont_resid
|
||||||
|
mean_real = x_real.mean(dim=(0, 1))
|
||||||
|
mean_gen = x_gen.mean(dim=(0, 1))
|
||||||
|
std_real = x_real.std(dim=(0, 1))
|
||||||
|
std_gen = x_gen.std(dim=(0, 1))
|
||||||
|
stat_loss = F.mse_loss(mean_gen, mean_real) + F.mse_loss(std_gen, std_real)
|
||||||
|
loss = loss + stat_weight * stat_loss
|
||||||
|
opt.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
if float(config.get("grad_clip", 0.0)) > 0:
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), float(config["grad_clip"]))
|
||||||
|
opt.step()
|
||||||
|
if ema is not None:
|
||||||
|
ema.update(model)
|
||||||
|
|
||||||
|
if step % int(config["log_every"]) == 0:
|
||||||
|
print("epoch", epoch, "step", step, "loss", float(loss))
|
||||||
|
with open(log_path, "a", encoding="utf-8") as f:
|
||||||
|
f.write(
|
||||||
|
"%d,%d,%.6f,%.6f,%.6f\n"
|
||||||
|
% (epoch, step, float(loss), float(loss_cont), float(loss_disc))
|
||||||
|
)
|
||||||
|
|
||||||
|
total_step += 1
|
||||||
|
if total_step % int(config["ckpt_every"]) == 0:
|
||||||
|
ckpt = {
|
||||||
|
"model": model.state_dict(),
|
||||||
|
"optim": opt.state_dict(),
|
||||||
|
"config": config,
|
||||||
|
"step": total_step,
|
||||||
|
"epoch": epoch,
|
||||||
|
"step_in_epoch": step + 1,
|
||||||
|
"temporal_done": temporal_done,
|
||||||
|
"temporal_step": temporal_total_step,
|
||||||
|
}
|
||||||
|
if ema is not None:
|
||||||
|
ckpt["ema"] = ema.state_dict()
|
||||||
|
if temporal_model is not None:
|
||||||
|
ckpt["temporal"] = temporal_model.state_dict()
|
||||||
|
if opt_temporal is not None:
|
||||||
|
ckpt["temporal_optim"] = opt_temporal.state_dict()
|
||||||
|
atomic_torch_save(ckpt, main_ckpt_path)
|
||||||
|
|
||||||
|
main_start_step = 0
|
||||||
|
ckpt = {
|
||||||
|
"model": model.state_dict(),
|
||||||
|
"optim": opt.state_dict(),
|
||||||
|
"config": config,
|
||||||
|
"step": total_step,
|
||||||
|
"epoch": epoch + 1,
|
||||||
|
"step_in_epoch": 0,
|
||||||
|
"temporal_done": temporal_done,
|
||||||
|
"temporal_step": temporal_total_step,
|
||||||
|
}
|
||||||
|
if ema is not None:
|
||||||
|
ckpt["ema"] = ema.state_dict()
|
||||||
|
if temporal_model is not None:
|
||||||
|
ckpt["temporal"] = temporal_model.state_dict()
|
||||||
|
if opt_temporal is not None:
|
||||||
|
ckpt["temporal_optim"] = opt_temporal.state_dict()
|
||||||
|
atomic_torch_save(ckpt, main_ckpt_path)
|
||||||
|
|
||||||
|
atomic_torch_save(model.state_dict(), model_path)
|
||||||
|
if ema is not None:
|
||||||
|
atomic_torch_save(ema.state_dict(), ema_path)
|
||||||
|
if temporal_model is not None:
|
||||||
|
atomic_torch_save(temporal_model.state_dict(), temporal_path)
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the training/submission pipeline defined in main()
# only when this file is executed directly (not when imported as a module).
if __name__ == "__main__":
    main()
|
||||||
Reference in New Issue
Block a user