From: François Fleuret Date: Sat, 11 Mar 2023 23:08:44 +0000 (+0100) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=3518f58472ceb6cf7ea3cdb62aabc7a368501348;p=beaver.git Update. --- diff --git a/beaver.py b/beaver.py index a289867..c68fe76 100755 --- a/beaver.py +++ b/beaver.py @@ -227,6 +227,7 @@ class TaskMaze(Task): result = input.clone() ar_mask = result.new_zeros(result.size()) ar_mask[:, self.height * self.width :] = 1 + result *= 1-ar_mask masked_inplace_autoregression(model, self.batch_size, result, ar_mask) mazes, paths = self.seq2map(result) nb_correct += maze.path_correctness(mazes, paths).long().sum() @@ -256,8 +257,8 @@ class TaskMaze(Task): input = self.test_input[:32] result = input.clone() ar_mask = result.new_zeros(result.size()) - ar_mask[:, self.height * self.width :] = 1 + result *= 1-ar_mask masked_inplace_autoregression(model, self.batch_size, result, ar_mask) mazes, paths = self.seq2map(input)