filename = f"ae_{model.id:03d}.pth"
try:
- d = torch.load(os.path.join(args.result_dir, filename))
+ d = torch.load(
+ os.path.join(args.result_dir, filename), map_location=main_device
+ )
model.load_state_dict(d["state_dict"])
model.optimizer.load_state_dict(d["optimizer_state_dict"])
model.test_accuracy = d["test_accuracy"]
else:
records.append(
- generate_ae_c_quizzes(
- models, nb_c_quizzes_to_generate, records, gpus[0]
- )
+ generate_ae_c_quizzes(models, nb_c_quizzes_to_generate, gpus[0])
)
time_c_quizzes = int(time.perf_counter() - start_time)