From b003cc9f89b7c3356f7d1e6c0c10b3dea249ef96 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 20 Jun 2023 16:13:20 +0200 Subject: [PATCH] Update. --- main.py | 51 +++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index 6e8ebff..acecfdd 100755 --- a/main.py +++ b/main.py @@ -101,7 +101,7 @@ parser.add_argument("--snake_width", type=int, default=8) parser.add_argument("--snake_nb_colors", type=int, default=3) -parser.add_argument("--snake_length", type=int, default=100) +parser.add_argument("--snake_length", type=int, default=400) ###################################################################### @@ -499,7 +499,7 @@ class TaskMNIST(Task): masked_inplace_autoregression( model, self.batch_size, results, ar_mask, device=self.device ) - image_name = os.path.join(args.result_dir, f"result_mnist_{n_epoch:04d}.png") + image_name = os.path.join(args.result_dir, f"mnist_result_{n_epoch:04d}.png") torchvision.utils.save_image( 1 - results.reshape(-1, 1, 28, 28) / 255.0, image_name, @@ -619,7 +619,7 @@ class TaskMaze(Task): mazes, paths = self.seq2map(input) _, predicted_paths = self.seq2map(result) - filename = os.path.join(args.result_dir, f"result_{n_epoch:04d}.png") + filename = os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.png") maze.save_image( filename, mazes=mazes, @@ -649,7 +649,7 @@ def generate_snake_sequences( ) snake_direction = torch.randint(4, (nb,), device=device) sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64) - count = torch.arange(nb, device=device) # [:,None] + i = torch.arange(nb, device=device) # [:,None] for l in range(length): # nb x 3 @@ -684,11 +684,10 @@ def generate_snake_sequences( ) # nb - i = torch.arange(val.size(0), device=device) j = val.argmax(1) snake_direction = snake_next_direction[i, j] - sequences[:, 2 * l] = worlds[count, snake_position[:, 0], snake_position[:, 1]] + sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4 sequences[:, 2 * l + 1] = snake_direction # nb x 2 @@ -696,8 +695,6 @@ def generate_snake_sequences( return sequences, worlds - # print(snake_position) - # generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20) # exit(0) @@ -744,6 +741,44 @@ class TaskSnake(Task): def vocabulary_size(self): return self.nb_codes + def produce_results(self, n_epoch, model): + with torch.autograd.no_grad(): + t = model.training + model.eval() + + def compute_nb_correct(input): + result = input.clone() + i = torch.arange(result.size(1), device=result.device) + ar_mask = torch.logical_and(i >= i.size(0) // 2, i % 2 == 0)[ + None, : + ].long() + result *= 1 - ar_mask + masked_inplace_autoregression( + model, self.batch_size, result, ar_mask, device=self.device + ) + + nb_total = ar_mask.sum() * input.size(0) + nb_correct = ((result == input).long() * ar_mask).sum() + + # nb_total = result.size(0) + # nb_correct = ((result - input).abs().sum(1) == 0).sum() + + return nb_total, nb_correct + + train_nb_total, train_nb_correct = compute_nb_correct(self.train_input) + + log_string( + f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" + ) + + test_nb_total, test_nb_correct = compute_nb_correct(self.test_input) + + log_string( + f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) + + model.train(t) + ###################################################################### -- 2.20.1