Update.
[picoclvr.git] / main.py
diff --git a/main.py b/main.py
index db982ca..0c2ff24 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -31,9 +31,11 @@ parser = argparse.ArgumentParser(
     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 
-parser.add_argument("--task", type=str, default="picoclvr")
+parser.add_argument(
+    "--task", type=str, default="picoclvr", help="picoclvr, mnist, maze, snake"
+)
 
-parser.add_argument("--log_filename", type=str, default="train.log")
+parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
 
 parser.add_argument("--result_dir", type=str, default="results_default")
 
@@ -509,7 +511,7 @@ class TaskPicoCLVR(Task):
 
         image_name = os.path.join(args.result_dir, f"picoclvr_result_{n_epoch:04d}.png")
         torchvision.utils.save_image(
-            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=1.0
+            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
         )
         log_string(f"wrote {image_name}")
 
@@ -619,39 +621,83 @@ class TaskMaze(Task):
 
     def compute_error(self, model, split="train", nb_to_use=-1):
         nb_total, nb_correct = 0, 0
-        for input in task.batches(split, nb_to_use):
+        count = torch.zeros(
+            self.width * self.height,
+            self.width * self.height,
+            device=self.device,
+            dtype=torch.int64,
+        )
+        for input in tqdm.tqdm(
+            task.batches(split, nb_to_use),
+            dynamic_ncols=True,
+            desc=f"test-mazes",
+        ):
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
             ar_mask[:, self.height * self.width :] = 1
             result *= 1 - ar_mask
             masked_inplace_autoregression(
-                model, self.batch_size, result, ar_mask, device=self.device
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                progress_bar_desc=None,
+                device=self.device,
             )
             mazes, paths = self.seq2map(result)
-            nb_correct += maze.path_correctness(mazes, paths).long().sum()
+            path_correctness = maze.path_correctness(mazes, paths)
+            nb_correct += path_correctness.long().sum()
             nb_total += mazes.size(0)
 
-        return nb_total, nb_correct
+            optimal_path_lengths = (
+                (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
+            )
+            predicted_path_lengths = (
+                (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
+            )
+            optimal_path_lengths = optimal_path_lengths[path_correctness]
+            predicted_path_lengths = predicted_path_lengths[path_correctness]
+            count[optimal_path_lengths, predicted_path_lengths] += 1
+
+        if count.max() == 0:
+            count = None
+        else:
+            count = count[
+                : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
+            ]
+
+        return nb_total, nb_correct, count
 
     def produce_results(self, n_epoch, model):
         with torch.autograd.no_grad():
             t = model.training
             model.eval()
 
-            train_nb_total, train_nb_correct = self.compute_error(
+            train_nb_total, train_nb_correct, count = self.compute_error(
                 model, "train", nb_to_use=1000
             )
             log_string(
                 f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
             )
 
-            test_nb_total, test_nb_correct = self.compute_error(
+            test_nb_total, test_nb_correct, count = self.compute_error(
                 model, "test", nb_to_use=1000
             )
             log_string(
                 f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
             )
 
+            if count is not None:
+                proportion_optimal = count.diagonal().sum().float() / count.sum()
+                log_string(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
+                with open(
+                    os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
+                ) as f:
+                    for i in range(count.size(0)):
+                        for j in range(count.size(1)):
+                            eol = " " if j < count.size(1) - 1 else "\n"
+                            f.write(f"{count[i,j]}{eol}")
+
             input = self.test_input[:48]
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
@@ -671,6 +717,7 @@ class TaskMaze(Task):
                 target_paths=paths,
                 predicted_paths=predicted_paths,
                 path_correct=maze.path_correctness(mazes, predicted_paths),
+                path_optimal=maze.path_optimality(paths, predicted_paths),
             )
             log_string(f"wrote {filename}")