update
This commit is contained in:
@@ -79,6 +79,8 @@ DEFAULTS = {
|
||||
"temporal_transformer_dropout": 0.1,
|
||||
"temporal_epochs": 2,
|
||||
"temporal_lr": 1e-3,
|
||||
"temporal_focus_type4": False,
|
||||
"temporal_exclude_type4": False,
|
||||
"quantile_loss_weight": 0.0,
|
||||
"quantile_points": [0.05, 0.25, 0.5, 0.75, 0.95],
|
||||
"snr_weighted_loss": True,
|
||||
@@ -110,6 +112,7 @@ def parse_args():
|
||||
parser.add_argument("--device", default="auto", help="cpu, cuda, or auto")
|
||||
parser.add_argument("--out-dir", default=None, help="Override output directory")
|
||||
parser.add_argument("--seed", type=int, default=None, help="Override random seed")
|
||||
parser.add_argument("--temporal-only", action="store_true", help="Only train temporal stage-1 and exit.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -178,6 +181,9 @@ def main():
|
||||
config["out_dir"] = str(out_dir)
|
||||
if args.seed is not None:
|
||||
config["seed"] = int(args.seed)
|
||||
if bool(args.temporal_only):
|
||||
config["use_temporal_stage1"] = True
|
||||
config["epochs"] = 0
|
||||
|
||||
set_seed(int(config["seed"]))
|
||||
|
||||
@@ -188,8 +194,10 @@ def main():
|
||||
|
||||
type1_cols = config.get("type1_features", []) or []
|
||||
type5_cols = config.get("type5_features", []) or []
|
||||
type4_cols = config.get("type4_features", []) or []
|
||||
type1_cols = [c for c in type1_cols if c in cont_cols]
|
||||
type5_cols = [c for c in type5_cols if c in cont_cols]
|
||||
type4_cols = [c for c in type4_cols if c in cont_cols]
|
||||
model_cont_cols = [c for c in cont_cols if c not in type1_cols and c not in type5_cols]
|
||||
if not model_cont_cols:
|
||||
raise SystemExit("model_cont_cols is empty; check type1/type5 config")
|
||||
@@ -243,6 +251,18 @@ def main():
|
||||
opt = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
|
||||
temporal_model = None
|
||||
opt_temporal = None
|
||||
temporal_use_type1_cond = bool(config.get("temporal_use_type1_cond", False))
|
||||
temporal_cond_dim = len(type1_cols) if (temporal_use_type1_cond and type1_cols) else 0
|
||||
temporal_focus_type4 = bool(config.get("temporal_focus_type4", False))
|
||||
temporal_exclude_type4 = bool(config.get("temporal_exclude_type4", False))
|
||||
type4_model_idx = [model_cont_cols.index(c) for c in type4_cols if c in model_cont_cols]
|
||||
trend_mask = None
|
||||
if temporal_focus_type4 and type4_model_idx:
|
||||
trend_mask = torch.zeros(1, 1, len(model_cont_cols), device=device)
|
||||
trend_mask[:, :, type4_model_idx] = 1.0
|
||||
elif temporal_exclude_type4 and type4_model_idx:
|
||||
trend_mask = torch.ones(1, 1, len(model_cont_cols), device=device)
|
||||
trend_mask[:, :, type4_model_idx] = 0.0
|
||||
if bool(config.get("use_temporal_stage1", False)):
|
||||
temporal_backbone = str(config.get("temporal_backbone", "gru"))
|
||||
if temporal_backbone == "transformer":
|
||||
@@ -255,6 +275,7 @@ def main():
|
||||
dropout=float(config.get("temporal_transformer_dropout", 0.1)),
|
||||
pos_dim=int(config.get("temporal_pos_dim", 64)),
|
||||
use_pos_embed=bool(config.get("temporal_use_pos_embed", True)),
|
||||
cond_dim=temporal_cond_dim,
|
||||
).to(device)
|
||||
else:
|
||||
temporal_model = TemporalGRUGenerator(
|
||||
@@ -262,6 +283,7 @@ def main():
|
||||
hidden_dim=int(config.get("temporal_hidden_dim", 256)),
|
||||
num_layers=int(config.get("temporal_num_layers", 1)),
|
||||
dropout=float(config.get("temporal_dropout", 0.0)),
|
||||
cond_dim=temporal_cond_dim,
|
||||
).to(device)
|
||||
opt_temporal = torch.optim.Adam(
|
||||
temporal_model.parameters(),
|
||||
@@ -306,8 +328,18 @@ def main():
|
||||
x_cont = x_cont.to(device)
|
||||
model_idx = [cont_cols.index(c) for c in model_cont_cols]
|
||||
x_cont_model = x_cont[:, :, model_idx]
|
||||
trend, pred_next = temporal_model.forward_teacher(x_cont_model)
|
||||
temporal_loss = F.mse_loss(pred_next, x_cont_model[:, 1:, :])
|
||||
cond_cont = None
|
||||
if temporal_cond_dim > 0:
|
||||
cond_idx = [cont_cols.index(c) for c in type1_cols]
|
||||
cond_cont = x_cont[:, :, cond_idx]
|
||||
trend, pred_next = temporal_model.forward_teacher(x_cont_model, cond_cont=cond_cont)
|
||||
target_next = x_cont_model[:, 1:, :]
|
||||
if trend_mask is not None:
|
||||
mask = trend_mask.to(dtype=pred_next.dtype, device=pred_next.device)
|
||||
mse = (pred_next - target_next) ** 2
|
||||
temporal_loss = (mse * mask).sum() / torch.clamp(mask.sum() * mse.size(0) * mse.size(1), min=1.0)
|
||||
else:
|
||||
temporal_loss = F.mse_loss(pred_next, target_next)
|
||||
opt_temporal.zero_grad()
|
||||
temporal_loss.backward()
|
||||
if float(config.get("grad_clip", 0.0)) > 0:
|
||||
@@ -356,7 +388,9 @@ def main():
|
||||
if temporal_model is not None:
|
||||
temporal_model.eval()
|
||||
with torch.no_grad():
|
||||
trend, _ = temporal_model.forward_teacher(x_cont_model)
|
||||
trend, _ = temporal_model.forward_teacher(x_cont_model, cond_cont=cond_cont)
|
||||
if trend_mask is not None and trend is not None:
|
||||
trend = trend * trend_mask.to(dtype=trend.dtype, device=trend.device)
|
||||
x_cont_resid = x_cont_model if trend is None else x_cont_model - trend
|
||||
|
||||
bsz = x_cont.size(0)
|
||||
|
||||
Reference in New Issue
Block a user