This commit is contained in:
Mingzhe Yang
2026-02-04 03:53:17 +08:00
parent 2072351c0d
commit 10c0721ee1
6 changed files with 1134 additions and 104 deletions

View File

@@ -73,6 +73,13 @@ def parse_args():
return parser.parse_args()
def load_torch_state(path: str, device: str):
try:
return torch.load(path, map_location=device, weights_only=True)
except TypeError:
return torch.load(path, map_location=device)
# 使用 platform_utils 中的 resolve_device 函数
@@ -193,9 +200,9 @@ def main():
).to(device)
if args.use_ema and os.path.exists(args.model_path.replace("model.pt", "model_ema.pt")):
ema_path = args.model_path.replace("model.pt", "model_ema.pt")
model.load_state_dict(torch.load(ema_path, map_location=device, weights_only=True))
model.load_state_dict(load_torch_state(ema_path, device))
else:
model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
model.load_state_dict(load_torch_state(args.model_path, device))
model.eval()
temporal_model = None
@@ -221,7 +228,7 @@ def main():
temporal_path = Path(args.model_path).with_name("temporal.pt")
if not temporal_path.exists():
raise SystemExit(f"missing temporal model file: {temporal_path}")
temporal_model.load_state_dict(torch.load(temporal_path, map_location=device, weights_only=True))
temporal_model.load_state_dict(load_torch_state(str(temporal_path), device))
temporal_model.eval()
betas = cosine_beta_schedule(args.timesteps).to(device)