update
This commit is contained in:
@@ -73,6 +73,13 @@ def parse_args():
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_torch_state(path: str, device: str):
|
||||
try:
|
||||
return torch.load(path, map_location=device, weights_only=True)
|
||||
except TypeError:
|
||||
return torch.load(path, map_location=device)
|
||||
|
||||
|
||||
# 使用 platform_utils 中的 resolve_device 函数
|
||||
|
||||
|
||||
@@ -193,9 +200,9 @@ def main():
|
||||
).to(device)
|
||||
if args.use_ema and os.path.exists(args.model_path.replace("model.pt", "model_ema.pt")):
|
||||
ema_path = args.model_path.replace("model.pt", "model_ema.pt")
|
||||
model.load_state_dict(torch.load(ema_path, map_location=device, weights_only=True))
|
||||
model.load_state_dict(load_torch_state(ema_path, device))
|
||||
else:
|
||||
model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))
|
||||
model.load_state_dict(load_torch_state(args.model_path, device))
|
||||
model.eval()
|
||||
|
||||
temporal_model = None
|
||||
@@ -221,7 +228,7 @@ def main():
|
||||
temporal_path = Path(args.model_path).with_name("temporal.pt")
|
||||
if not temporal_path.exists():
|
||||
raise SystemExit(f"missing temporal model file: {temporal_path}")
|
||||
temporal_model.load_state_dict(torch.load(temporal_path, map_location=device, weights_only=True))
|
||||
temporal_model.load_state_dict(load_torch_state(str(temporal_path), device))
|
||||
temporal_model.eval()
|
||||
|
||||
betas = cosine_beta_schedule(args.timesteps).to(device)
|
||||
|
||||
Reference in New Issue
Block a user