Update
authorFrançois Fleuret <francois@fleuret.org>
Thu, 23 Mar 2023 08:48:55 +0000 (09:48 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 23 Mar 2023 08:48:55 +0000 (09:48 +0100)
beaver.py

index dca97cc..bd17365 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -312,15 +312,17 @@ def oneshot(gpt, task):
         scores = scores.reshape(-1, task.height, task.width)
         mazes = mazes.reshape(-1, task.height, task.width)
         targets = targets.reshape(-1, task.height, task.width)
+        filename = (
+            f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png"
+        )
         maze.save_image(
-            os.path.join(
-                args.result_dir,
-                f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png",
-            ),
+            os.path.join(args.result_dir, filename),
             mazes=mazes,
             score_paths=scores,
             score_truth=targets,
         )
+        log_string(f"wrote {filename}")
+
         # -------------------
 
     gpt.train(t)
@@ -471,13 +473,15 @@ class TaskMaze(Task):
 
             mazes, paths = self.seq2map(input)
             _, predicted_paths = self.seq2map(result)
+            filename = f"result_{n_epoch:04d}.png"
             maze.save_image(
-                os.path.join(args.result_dir, f"result_{n_epoch:04d}.png"),
+                os.path.join(args.result_dir, filename),
                 mazes=mazes,
                 target_paths=paths,
                 predicted_paths=predicted_paths,
                 path_correct=maze.path_correctness(mazes, predicted_paths),
             )
+            log_string(f"wrote {filename}")
 
             model.train(t)