update
This commit is contained in:
@@ -29,6 +29,13 @@ BATCH_SIZE = 2
|
||||
CLIP_K = 5.0
|
||||
|
||||
|
||||
def load_torch_state(path: str, device: str):
|
||||
try:
|
||||
return torch.load(path, map_location=device, weights_only=True)
|
||||
except TypeError:
|
||||
return torch.load(path, map_location=device)
|
||||
|
||||
|
||||
def load_vocab():
|
||||
with open(str(VOCAB_PATH), "r", encoding="utf-8") as f:
|
||||
return json.load(f)["vocab"]
|
||||
@@ -110,7 +117,7 @@ def main():
|
||||
eps_scale=eps_scale,
|
||||
).to(DEVICE)
|
||||
if MODEL_PATH.exists():
|
||||
model.load_state_dict(torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=True))
|
||||
model.load_state_dict(load_torch_state(str(MODEL_PATH), DEVICE))
|
||||
model.eval()
|
||||
|
||||
temporal_model = None
|
||||
@@ -136,7 +143,7 @@ def main():
|
||||
temporal_path = BASE_DIR / "results" / "temporal.pt"
|
||||
if not temporal_path.exists():
|
||||
raise SystemExit(f"missing temporal model file: {temporal_path}")
|
||||
temporal_model.load_state_dict(torch.load(str(temporal_path), map_location=DEVICE, weights_only=True))
|
||||
temporal_model.load_state_dict(load_torch_state(str(temporal_path), DEVICE))
|
||||
temporal_model.eval()
|
||||
|
||||
betas = cosine_beta_schedule(timesteps).to(DEVICE)
|
||||
|
||||
Reference in New Issue
Block a user