X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=45bddb762721c468d8990de9cc88044c5cbbd635;hb=f23843d33a4fa5a38f5034deab8f473793732ee3;hp=014481784a3979f015889149d671d203c1650e1b;hpb=d8ec2ebf14b7299b246456a440ff15e97cfae472;p=picoclvr.git diff --git a/main.py b/main.py index 0144817..45bddb7 100755 --- a/main.py +++ b/main.py @@ -187,6 +187,8 @@ def masked_inplace_autoregression( progress_bar_desc="autoregression", device=torch.device("cpu"), ): + # p = logits.softmax(1) + # entropy[:,s]= p.xlogy(p).sum(1) / math.log(2) batches = zip(input.split(batch_size), ar_mask.split(batch_size)) if progress_bar_desc is not None: tqdm.tqdm( @@ -511,7 +513,7 @@ class TaskPicoCLVR(Task): image_name = os.path.join(args.result_dir, f"picoclvr_result_{n_epoch:04d}.png") torchvision.utils.save_image( - img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=1.0 + img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0 ) log_string(f"wrote {image_name}") @@ -622,15 +624,27 @@ class TaskMaze(Task): def compute_error(self, model, split="train", nb_to_use=-1): nb_total, nb_correct = 0, 0 count = torch.zeros( - self.width * self.height, self.width * self.height, device=self.device + self.width * self.height, + self.width * self.height, + device=self.device, + dtype=torch.int64, ) - for input in task.batches(split, nb_to_use): + for input in tqdm.tqdm( + task.batches(split, nb_to_use), + dynamic_ncols=True, + desc=f"test-mazes", + ): result = input.clone() ar_mask = result.new_zeros(result.size()) ar_mask[:, self.height * self.width :] = 1 result *= 1 - ar_mask masked_inplace_autoregression( - model, self.batch_size, result, ar_mask, device=self.device + model, + self.batch_size, + result, + ar_mask, + progress_bar_desc=None, + device=self.device, ) mazes, paths = self.seq2map(result) path_correctness = maze.path_correctness(mazes, paths) @@ -676,6 +690,8 @@ class TaskMaze(Task): ) if count is not None: + proportion_optimal = count.diagonal().sum().float() / count.sum() + log_string(f"proportion_optimal_test {proportion_optimal*100:.02f}%") with open( os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.txt"), "w" ) as f: @@ -703,6 +719,7 @@ class TaskMaze(Task): target_paths=paths, predicted_paths=predicted_paths, path_correct=maze.path_correctness(mazes, predicted_paths), + path_optimal=maze.path_optimality(paths, predicted_paths), ) log_string(f"wrote {filename}")