From 4d256526d60f6760ba3e13a72da1d38aa67f60ec Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 21 Sep 2024 08:05:26 +0200 Subject: [PATCH] Update. --- grids.py | 8 ++++---- main.py | 4 +++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/grids.py b/grids.py index 78d9297..ac25781 100755 --- a/grids.py +++ b/grids.py @@ -140,7 +140,7 @@ class Grids(problem.Problem): # dots = False grid_gray = 240 - thickness = 0 + thickness = 1 background_gray = 240 dots = False @@ -287,9 +287,9 @@ class Grids(problem.Problem): ###################################################################### def vocabulary_size(self): - # warnings.warn("hack +4 to keep the vocabulary size unchanged", RuntimeWarning) - # return self.nb_colors+4 - return self.nb_colors + warnings.warn("hack +4 to keep the vocabulary size unchanged", RuntimeWarning) + return self.nb_colors + 4 + # return self.nb_colors def grid2img(self, x, scale=15, grids=True): m = torch.logical_and(x >= 0, x < self.nb_colors).long() diff --git a/main.py b/main.py index 5dceefc..0311450 100755 --- a/main.py +++ b/main.py @@ -631,6 +631,7 @@ def max_nb_mistakes_on_one_grid(quizzes, prediction): def evaluate_quizzes(quizzes, models, with_hints, local_device): nb_correct, nb_wrong = 0, 0 + quizzes = quizzes.to(local_device) for model in models: model = copy.deepcopy(model).to(local_device).eval() @@ -955,8 +956,9 @@ for n_epoch in range(current_epoch, args.nb_epochs): quizzes=train_c_quizzes, models=models, with_hints=False, - local_device=local_device, + local_device=main_device, ) + nb_correct = nb_correct.to("cpu") test_c_quizzes = train_c_quizzes[nb_correct >= args.nb_have_to_be_correct] -- 2.39.5