This commit is contained in:
Mingzhe Yang
2026-02-04 03:53:17 +08:00
parent 2072351c0d
commit 10c0721ee1
6 changed files with 1134 additions and 104 deletions

View File

@@ -29,6 +29,13 @@ BATCH_SIZE = 2
CLIP_K = 5.0
def load_torch_state(path: str, device: str):
try:
return torch.load(path, map_location=device, weights_only=True)
except TypeError:
return torch.load(path, map_location=device)
def load_vocab():
with open(str(VOCAB_PATH), "r", encoding="utf-8") as f:
return json.load(f)["vocab"]
@@ -110,7 +117,7 @@ def main():
eps_scale=eps_scale,
).to(DEVICE)
if MODEL_PATH.exists():
model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
model.load_state_dict(load_torch_state(str(MODEL_PATH), DEVICE))
model.eval()
temporal_model = None
@@ -136,7 +143,7 @@ def main():
temporal_path = BASE_DIR / "results" / "temporal.pt"
if not temporal_path.exists():
raise SystemExit(f"missing temporal model file: {temporal_path}")
temporal_model.load_state_dict(torch.load(str(temporal_path), map_location=DEVICE, weights_only=True))
temporal_model.load_state_dict(load_torch_state(str(temporal_path), DEVICE))
temporal_model.eval()
betas = cosine_beta_schedule(timesteps).to(DEVICE)