From 3518f58472ceb6cf7ea3cdb62aabc7a368501348 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 12 Mar 2023 00:08:44 +0100 Subject: [PATCH] Update. --- beaver.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/beaver.py b/beaver.py index a289867..c68fe76 100755 --- a/beaver.py +++ b/beaver.py @@ -227,6 +227,7 @@ class TaskMaze(Task): result = input.clone() ar_mask = result.new_zeros(result.size()) ar_mask[:, self.height * self.width :] = 1 + result *= 1-ar_mask masked_inplace_autoregression(model, self.batch_size, result, ar_mask) mazes, paths = self.seq2map(result) nb_correct += maze.path_correctness(mazes, paths).long().sum() @@ -256,8 +257,8 @@ class TaskMaze(Task): input = self.test_input[:32] result = input.clone() ar_mask = result.new_zeros(result.size()) - ar_mask[:, self.height * self.width :] = 1 + result *= 1-ar_mask masked_inplace_autoregression(model, self.batch_size, result, ar_mask) mazes, paths = self.seq2map(input) -- 2.39.5