X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=0c2ff24dd8907845418b13c5867ba41469c83529;hb=3c97745cdf9ae30a87903e3039e38c868e136d6e;hp=014481784a3979f015889149d671d203c1650e1b;hpb=d8ec2ebf14b7299b246456a440ff15e97cfae472;p=picoclvr.git diff --git a/main.py b/main.py index 0144817..0c2ff24 100755 --- a/main.py +++ b/main.py @@ -511,7 +511,7 @@ class TaskPicoCLVR(Task): image_name = os.path.join(args.result_dir, f"picoclvr_result_{n_epoch:04d}.png") torchvision.utils.save_image( - img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=1.0 + img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0 ) log_string(f"wrote {image_name}") @@ -622,15 +622,27 @@ class TaskMaze(Task): def compute_error(self, model, split="train", nb_to_use=-1): nb_total, nb_correct = 0, 0 count = torch.zeros( - self.width * self.height, self.width * self.height, device=self.device + self.width * self.height, + self.width * self.height, + device=self.device, + dtype=torch.int64, ) - for input in task.batches(split, nb_to_use): + for input in tqdm.tqdm( + task.batches(split, nb_to_use), + dynamic_ncols=True, + desc=f"test-mazes", + ): result = input.clone() ar_mask = result.new_zeros(result.size()) ar_mask[:, self.height * self.width :] = 1 result *= 1 - ar_mask masked_inplace_autoregression( - model, self.batch_size, result, ar_mask, device=self.device + model, + self.batch_size, + result, + ar_mask, + progress_bar_desc=None, + device=self.device, ) mazes, paths = self.seq2map(result) path_correctness = maze.path_correctness(mazes, paths) @@ -676,6 +688,8 @@ class TaskMaze(Task): ) if count is not None: + proportion_optimal = count.diagonal().sum().float() / count.sum() + log_string(f"proportion_optimal_test {proportion_optimal*100:.02f}%") with open( os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.txt"), "w" ) as f: @@ -703,6 +717,7 @@ class TaskMaze(Task): target_paths=paths, predicted_paths=predicted_paths, path_correct=maze.path_correctness(mazes, predicted_paths), + path_optimal=maze.path_optimality(paths, predicted_paths), ) log_string(f"wrote {filename}")