Update
authorFrançois Fleuret <francois@fleuret.org>
Sun, 12 Mar 2023 07:12:00 +0000 (08:12 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 12 Mar 2023 07:12:00 +0000 (08:12 +0100)
beaver.py

index c68fe76..b0e8a78 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -227,7 +227,7 @@ class TaskMaze(Task):
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
             ar_mask[:, self.height * self.width :] = 1
-            result *= 1-ar_mask
+            result *= 1 - ar_mask
             masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
             mazes, paths = self.seq2map(result)
             nb_correct += maze.path_correctness(mazes, paths).long().sum()
@@ -258,13 +258,13 @@ class TaskMaze(Task):
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
             ar_mask[:, self.height * self.width :] = 1
-            result *= 1-ar_mask
+            result *= 1 - ar_mask
             masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
 
             mazes, paths = self.seq2map(input)
             _, predicted_paths = self.seq2map(result)
             maze.save_image(
-                f"result_{n_epoch:04d}.png",
+                os.path.join(args.result_dir, f"result_{n_epoch:04d}.png"),
                 mazes,
                 paths,
                 predicted_paths,