From 1e896de274e6baa690d90e1b0bfe9b45f983c1a9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 17 Sep 2024 08:27:36 +0200 Subject: [PATCH] Update. --- main.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 899a099..e921ccd 100755 --- a/main.py +++ b/main.py @@ -804,14 +804,14 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True): def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device): # train - one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True) - one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=False) + one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=True) + one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=False) # predict quizzes = quiz_machine.quiz_set(150, c_quizzes, args.c_quiz_multiplier) input, targets, mask = batch_prediction(quizzes.to(local_device)) - result = predict(model, input, targets, mask).to("cpu") + result = predict(model, input, targets, mask, local_device=local_device).to("cpu") mask = mask.to("cpu") correct = (quizzes == result).min(dim=1).values.long() correct_parts = (2 * correct - 1)[:, None] * mask.reshape(mask.size(0), 4, -1)[ @@ -830,7 +830,7 @@ def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device): # generate - result = generate(model, 25).to("cpu") + result = generate(model, 25, local_device=local_device).to("cpu") quiz_machine.problem.save_quizzes_as_image( args.result_dir, f"culture_generation_{n_epoch}_{model.id}.png", -- 2.39.5