This commit is contained in:
MZ YANG
2026-02-10 18:24:59 +08:00
parent 7eee14ba2a
commit ccb33bf876
20 changed files with 174700 additions and 986 deletions

View File

@@ -79,6 +79,8 @@ DEFAULTS = {
"temporal_transformer_dropout": 0.1,
"temporal_epochs": 2,
"temporal_lr": 1e-3,
"temporal_focus_type4": False,
"temporal_exclude_type4": False,
"quantile_loss_weight": 0.0,
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
"snr_weighted_loss": True,
@@ -110,6 +112,7 @@ def parse_args():
parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
parser.add_argument("--out-dir", default=None, help="Override output directory")
parser.add_argument("--seed", type=int, default=None, help="Override random seed")
parser.add_argument("--temporal-only", action="store_true", help="Only train temporal stage-1 and exit.")
return parser.parse_args()
@@ -178,6 +181,9 @@ def main():
config["out_dir"] = str(out_dir)
if args.seed is not None:
config["seed"] = int(args.seed)
if bool(args.temporal_only):
config["use_temporal_stage1"] = True
config["epochs"] = 0
set_seed(int(config["seed"]))
@@ -188,8 +194,10 @@ def main():
type1_cols = config.get("type1_features", []) or []
type5_cols = config.get("type5_features", []) or []
type4_cols = config.get("type4_features", []) or []
type1_cols = [c for c in type1_cols if c in cont_cols]
type5_cols = [c for c in type5_cols if c in cont_cols]
type4_cols = [c for c in type4_cols if c in cont_cols]
model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
if not model_cont_cols:
raise SystemExit("model_cont_cols is empty; check type1/type5 config")
@@ -243,6 +251,18 @@ def main():
opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
temporal_model = None
opt_temporal = None
temporal_use_type1_cond = bool(config.get("temporal_use_type1_cond", False))
temporal_cond_dim = len(type1_cols) if (temporal_use_type1_cond and type1_cols) else 0
temporal_focus_type4 = bool(config.get("temporal_focus_type4", False))
temporal_exclude_type4 = bool(config.get("temporal_exclude_type4", False))
type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols]
trend_mask = None
if temporal_focus_type4 and type4_model_idx:
trend_mask = torch.zeros(1, 1, len(model_cont_cols), device=device)
trend_mask[:, :, type4_model_idx] = 1.0
elif temporal_exclude_type4 and type4_model_idx:
trend_mask = torch.ones(1, 1, len(model_cont_cols), device=device)
trend_mask[:, :, type4_model_idx] = 0.0
if bool(config.get("use_temporal_stage1", False)):
temporal_backbone = str(config.get("temporal_backbone", "gru"))
if temporal_backbone == "transformer":
@@ -255,6 +275,7 @@ def main():
dropout=float(config.get("temporal_transformer_dropout", 0.1)),
pos_dim=int(config.get("temporal_pos_dim", 64)),
use_pos_embed=bool(config.get("temporal_use_pos_embed", True)),
cond_dim=temporal_cond_dim,
).to(device)
else:
temporal_model = TemporalGRUGenerator(
@@ -262,6 +283,7 @@ def main():
hidden_dim=int(config.get("temporal_hidden_dim", 256)),
num_layers=int(config.get("temporal_num_layers", 1)),
dropout=float(config.get("temporal_dropout", 0.0)),
cond_dim=temporal_cond_dim,
).to(device)
opt_temporal = torch.optim.Adam(
temporal_model.parameters(),
@@ -306,8 +328,18 @@ def main():
x_cont = x_cont.to(device)
model_idx = [cont_cols.index(c) for c in model_cont_cols]
x_cont_model = x_cont[:, :, model_idx]
trend, pred_next = temporal_model.forward_teacher(x_cont_model)
temporal_loss = F.mse_loss(pred_next, x_cont_model[:, 1:, :])
cond_cont = None
if temporal_cond_dim > 0:
cond_idx = [cont_cols.index(c) for c in type1_cols]
cond_cont = x_cont[:, :, cond_idx]
trend, pred_next = temporal_model.forward_teacher(x_cont_model, cond_cont=cond_cont)
target_next = x_cont_model[:, 1:, :]
if trend_mask is not None:
mask = trend_mask.to(dtype=pred_next.dtype, device=pred_next.device)
mse = (pred_next - target_next) ** 2
temporal_loss = (mse * mask).sum() / torch.clamp(mask.sum() * mse.size(0) * mse.size(1), min=1.0)
else:
temporal_loss = F.mse_loss(pred_next, target_next)
opt_temporal.zero_grad()
temporal_loss.backward()
if float(config.get("grad_clip", 0.0)) > 0:
@@ -356,7 +388,9 @@ def main():
if temporal_model is not None:
temporal_model.eval()
with torch.no_grad():
trend, _ = temporal_model.forward_teacher(x_cont_model)
trend, _ = temporal_model.forward_teacher(x_cont_model, cond_cont=cond_cont)
if trend_mask is not None and trend is not None:
trend = trend * trend_mask.to(dtype=trend.dtype, device=trend.device)
x_cont_resid = x_cont_model if trend is None else x_cont_model - trend
bsz = x_cont.size(0)