Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 21 Sep 2024 06:05:26 +0000 (08:05 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 21 Sep 2024 06:05:26 +0000 (08:05 +0200)
grids.py
main.py

index 78d9297..ac25781 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -140,7 +140,7 @@ class Grids(problem.Problem):
     # dots = False
 
     grid_gray = 240
-    thickness = 0
+    thickness = 1
     background_gray = 240
     dots = False
 
@@ -287,9 +287,9 @@ class Grids(problem.Problem):
     ######################################################################
 
     def vocabulary_size(self):
-        warnings.warn("hack +4 to keep the vocabulary size unchanged", RuntimeWarning)
-        # return self.nb_colors+4
-        return self.nb_colors
+        warnings.warn("hack +4 to keep the vocabulary size unchanged", RuntimeWarning)
+        return self.nb_colors + 4
+        return self.nb_colors
 
     def grid2img(self, x, scale=15, grids=True):
         m = torch.logical_and(x >= 0, x < self.nb_colors).long()
diff --git a/main.py b/main.py
index 5dceefc..0311450 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -631,6 +631,7 @@ def max_nb_mistakes_on_one_grid(quizzes, prediction):
 
 def evaluate_quizzes(quizzes, models, with_hints, local_device):
     nb_correct, nb_wrong = 0, 0
+    quizzes = quizzes.to(local_device)
 
     for model in models:
         model = copy.deepcopy(model).to(local_device).eval()
@@ -955,8 +956,9 @@ for n_epoch in range(current_epoch, args.nb_epochs):
             quizzes=train_c_quizzes,
             models=models,
             with_hints=False,
-            local_device=local_device,
+            local_device=main_device,
         )
+        nb_correct = nb_correct.to("cpu")
 
         test_c_quizzes = train_c_quizzes[nb_correct >= args.nb_have_to_be_correct]