X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=014481784a3979f015889149d671d203c1650e1b;hb=d8ec2ebf14b7299b246456a440ff15e97cfae472;hp=7cb8d4f96ab03542205db29e333c8b22e176a1ae;hpb=74311726e42dccb8bc096e86a7e9000576099bab;p=picoclvr.git

diff --git a/main.py b/main.py
index 7cb8d4f..0144817 100755
--- a/main.py
+++ b/main.py
@@ -31,9 +31,11 @@ parser = argparse.ArgumentParser(
     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 
-parser.add_argument("--task", type=str, default="picoclvr")
+parser.add_argument(
+    "--task", type=str, default="picoclvr", help="picoclvr, mnist, maze, snake"
+)
 
-parser.add_argument("--log_filename", type=str, default="train.log")
+parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
 
 parser.add_argument("--result_dir", type=str, default="results_default")
 
@@ -173,15 +175,27 @@ for n in vars(args):
 
 ######################################################################
 
 
+# ar_mask is boolean, with 1s on the values to generate
+
+
 def masked_inplace_autoregression(
-    model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu")
+    model,
+    batch_size,
+    input,
+    ar_mask,
+    forbidden_tokens=None,
+    progress_bar_desc="autoregression",
+    device=torch.device("cpu"),
 ):
-    for input, ar_mask in tqdm.tqdm(
-        zip(input.split(batch_size), ar_mask.split(batch_size)),
-        dynamic_ncols=True,
-        desc="autoregression",
-        total=input.size(0) // batch_size,
-    ):
+    batches = zip(input.split(batch_size), ar_mask.split(batch_size))
+    if progress_bar_desc is not None:
+        batches = tqdm.tqdm(
+            batches,
+            dynamic_ncols=True,
+            desc=progress_bar_desc,
+            total=input.size(0) // batch_size,
+        )
+    for input, ar_mask in batches:
         i = (ar_mask.sum(0) > 0).nonzero()
         if i.min() > 0:
             model(
@@ -317,6 +331,7 @@ class TaskPicoCLVR(Task):
                 input,
                 ar_masks,
                 forbidden_tokens,
+                progress_bar_desc=None,
                 device=self.device,
             )
             model.train(t)
@@ -606,6 +621,9 @@ class TaskMaze(Task):
 
     def compute_error(self, model, split="train", nb_to_use=-1):
         nb_total, nb_correct = 0, 0
+        count = torch.zeros(
+            self.width * self.height, self.width * self.height, device=self.device
+        )
         for input in task.batches(split, nb_to_use):
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
@@ -615,30 +633,57 @@
                 model, self.batch_size, result, ar_mask, device=self.device
             )
             mazes, paths = self.seq2map(result)
-            nb_correct += maze.path_correctness(mazes, paths).long().sum()
+            path_correctness = maze.path_correctness(mazes, paths)
+            nb_correct += path_correctness.long().sum()
             nb_total += mazes.size(0)
 
-        return nb_total, nb_correct
+            optimal_path_lengths = (
+                (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
+            )
+            predicted_path_lengths = (
+                (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
+            )
+            optimal_path_lengths = optimal_path_lengths[path_correctness]
+            predicted_path_lengths = predicted_path_lengths[path_correctness]
+            count.index_put_((optimal_path_lengths, predicted_path_lengths), torch.ones_like(optimal_path_lengths, dtype=count.dtype), accumulate=True)
+
+        if count.max() == 0:
+            count = None
+        else:
+            count = count[
+                : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
+            ]
+
+        return nb_total, nb_correct, count
 
     def produce_results(self, n_epoch, model):
         with torch.autograd.no_grad():
             t = model.training
             model.eval()
 
-            train_nb_total, train_nb_correct = self.compute_error(
+            train_nb_total, train_nb_correct, count = self.compute_error(
                 model, "train", nb_to_use=1000
             )
             log_string(
                 f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
             )
 
-            test_nb_total, test_nb_correct = self.compute_error(
+            test_nb_total, test_nb_correct, count = self.compute_error(
                 model, "test", nb_to_use=1000
             )
             log_string(
                 f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
             )
 
+            if count is not None:
+                with open(
+                    os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
+                ) as f:
+                    for i in range(count.size(0)):
+                        for j in range(count.size(1)):
+                            eol = " " if j < count.size(1) - 1 else "\n"
+                            f.write(f"{count[i,j]}{eol}")
+
             input = self.test_input[:48]
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
@@ -975,9 +1020,6 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
             for input in task.batches(split="test"):
                 input = input.to(device)
 
-                # input, loss_masks, true_images = task.excise_last_image(input)
-                # input, loss_masks = task.add_true_image(input, true_images, loss_masks)
-
                 output = model(mygpt.BracketedSequence(input)).x
                 loss = F.cross_entropy(output.transpose(1, 2), input)
                 acc_test_loss += loss.item() * input.size(0)
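
Two notes on the hunks above, with short illustrative sketches.

The masked_inplace_autoregression change makes the progress bar optional: the batch iterator is built first and wrapped in tqdm.tqdm only when a description is given (TaskPicoCLVR passes progress_bar_desc=None to silence it during its inner generation loop). Since tqdm.tqdm returns a new iterator, the wrapper has to be assigned back to `batches` to have any effect. A minimal self-contained sketch of the same pattern, with invented tensor shapes, a hypothetical function name, and a no-op loop body:

    import torch
    import tqdm

    def autoregression_sketch(
        input, ar_mask, batch_size, progress_bar_desc="autoregression"
    ):
        # Build the plain batch iterator first ...
        batches = zip(input.split(batch_size), ar_mask.split(batch_size))
        if progress_bar_desc is not None:
            # ... and wrap it only when a progress bar is wanted; tqdm.tqdm
            # returns a new iterator, so the result must be kept.
            batches = tqdm.tqdm(
                batches,
                dynamic_ncols=True,
                desc=progress_bar_desc,
                total=input.size(0) // batch_size,
            )
        for input, ar_mask in batches:
            pass  # per-batch in-place generation would go here

    autoregression_sketch(
        torch.zeros(100, 16, dtype=torch.long),
        torch.ones(100, 16, dtype=torch.long),
        batch_size=25,
    )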
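The count matrix added to TaskMaze.compute_error is a 2D histogram over (optimal path length, predicted path length) pairs, accumulated only over correctly predicted paths and cropped to the smallest top-left block containing every nonzero entry before being returned. The sketch below, with invented values, shows the two PyTorch details the hunk relies on: index_put_ with accumulate=True counts duplicated index pairs (plain advanced-indexing `count[i, j] += 1` would record each distinct pair at most once per batch), and slicing by the last nonzero row/column marginal trims the histogram:

    import torch

    count = torch.zeros(10, 10)

    # Length pairs from one hypothetical batch; (3, 3) occurs twice.
    optimal = torch.tensor([3, 3, 5])
    predicted = torch.tensor([3, 3, 6])

    # accumulate=True adds 1 for every occurrence, duplicates included.
    count.index_put_(
        (optimal, predicted),
        torch.ones_like(optimal, dtype=count.dtype),
        accumulate=True,
    )

    # Keep rows/columns up to the last nonzero marginal, as compute_error does.
    count = count[
        : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
    ]
    print(count.size())  # torch.Size([6, 7]); count[3, 3] == 2, count[5, 6] == 1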