From: François Fleuret Date: Tue, 10 Sep 2024 07:05:29 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=276d3ec2f05b3e7061cb8389eb528719084a3905;p=culture.git Update. --- diff --git a/main.py b/main.py index 97d37ce..b7050df 100755 --- a/main.py +++ b/main.py @@ -1230,7 +1230,9 @@ if args.resume: filename = f"ae_{model.id:03d}.pth" try: - d = torch.load(os.path.join(args.result_dir, filename)) + d = torch.load( + os.path.join(args.result_dir, filename), map_location=main_device + ) model.load_state_dict(d["state_dict"]) model.optimizer.load_state_dict(d["optimizer_state_dict"]) model.test_accuracy = d["test_accuracy"] @@ -1378,9 +1380,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): else: records.append( - generate_ae_c_quizzes( - models, nb_c_quizzes_to_generate, records, gpus[0] - ) + generate_ae_c_quizzes(models, nb_c_quizzes_to_generate, gpus[0]) ) time_c_quizzes = int(time.perf_counter() - start_time)