Update.
author François Fleuret <francois@fleuret.org>
Tue, 20 Jun 2023 14:13:20 +0000 (16:13 +0200)
committer François Fleuret <francois@fleuret.org>
Tue, 20 Jun 2023 14:13:20 +0000 (16:13 +0200)
main.py

diff --git a/main.py b/main.py
index 6e8ebff..acecfdd 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -101,7 +101,7 @@ parser.add_argument("--snake_width", type=int, default=8)
 
 parser.add_argument("--snake_nb_colors", type=int, default=3)
 
-parser.add_argument("--snake_length", type=int, default=100)
+parser.add_argument("--snake_length", type=int, default=400)
 
 ######################################################################
 
@@ -499,7 +499,7 @@ class TaskMNIST(Task):
         masked_inplace_autoregression(
             model, self.batch_size, results, ar_mask, device=self.device
         )
-        image_name = os.path.join(args.result_dir, f"result_mnist_{n_epoch:04d}.png")
+        image_name = os.path.join(args.result_dir, f"mnist_result_{n_epoch:04d}.png")
         torchvision.utils.save_image(
             1 - results.reshape(-1, 1, 28, 28) / 255.0,
             image_name,
@@ -619,7 +619,7 @@ class TaskMaze(Task):
             mazes, paths = self.seq2map(input)
             _, predicted_paths = self.seq2map(result)
 
-            filename = os.path.join(args.result_dir, f"result_{n_epoch:04d}.png")
+            filename = os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.png")
             maze.save_image(
                 filename,
                 mazes=mazes,
@@ -649,7 +649,7 @@ def generate_snake_sequences(
     )
     snake_direction = torch.randint(4, (nb,), device=device)
     sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
-    count = torch.arange(nb, device=device)  # [:,None]
+    i = torch.arange(nb, device=device)  # [:,None]
 
     for l in range(length):
         # nb x 3
@@ -684,11 +684,10 @@ def generate_snake_sequences(
         )
 
         # nb
-        i = torch.arange(val.size(0), device=device)
         j = val.argmax(1)
         snake_direction = snake_next_direction[i, j]
 
-        sequences[:, 2 * l] = worlds[count, snake_position[:, 0], snake_position[:, 1]]
+        sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
         sequences[:, 2 * l + 1] = snake_direction
 
         # nb x 2
@@ -696,8 +695,6 @@ def generate_snake_sequences(
 
     return sequences, worlds
 
-    # print(snake_position)
-
 
 # generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
 # exit(0)
@@ -744,6 +741,44 @@ class TaskSnake(Task):
     def vocabulary_size(self):
         return self.nb_codes
 
+    def produce_results(self, n_epoch, model):
+        with torch.autograd.no_grad():
+            t = model.training
+            model.eval()
+
+            def compute_nb_correct(input):
+                result = input.clone()
+                i = torch.arange(result.size(1), device=result.device)
+                ar_mask = torch.logical_and(i >= i.size(0) // 2, i % 2 == 0)[
+                    None, :
+                ].long()
+                result *= 1 - ar_mask
+                masked_inplace_autoregression(
+                    model, self.batch_size, result, ar_mask, device=self.device
+                )
+
+                nb_total = ar_mask.sum() * input.size(0)
+                nb_correct = ((result == input).long() * ar_mask).sum()
+
+                # nb_total = result.size(0)
+                # nb_correct = ((result - input).abs().sum(1) == 0).sum()
+
+                return nb_total, nb_correct
+
+            train_nb_total, train_nb_correct = compute_nb_correct(self.train_input)
+
+            log_string(
+                f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+            )
+
+            test_nb_total, test_nb_correct = compute_nb_correct(self.test_input)
+
+            log_string(
+                f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+            )
+
+            model.train(t)
+
 
 ######################################################################