X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=beaver.py;h=c68fe7660afc7d9eeb01379576bf9eb0dc2f51c4;hb=3518f58472ceb6cf7ea3cdb62aabc7a368501348;hp=4d4f98d9621d77290a7b1593af5568b113a1306c;hpb=d44d0605fed828b8cea08c8e1c5bda7e4528ea97;p=beaver.git

diff --git a/beaver.py b/beaver.py
index 4d4f98d..c68fe76 100755
--- a/beaver.py
+++ b/beaver.py
@@ -38,9 +38,11 @@ parser.add_argument("--seed", type=int, default=0)
 
 parser.add_argument("--nb_epochs", type=int, default=25)
 
-parser.add_argument("--batch_size", type=int, default=100)
+parser.add_argument("--nb_train_samples", type=int, default=200000)
 
-parser.add_argument("--data_size", type=int, default=-1)
+parser.add_argument("--nb_test_samples", type=int, default=50000)
+
+parser.add_argument("--batch_size", type=int, default=25)
 
 parser.add_argument("--optim", type=str, default="adam")
 
@@ -73,11 +75,11 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 ##############################
 # maze options
 
-parser.add_argument("--world_height", type=int, default=13)
+parser.add_argument("--maze_height", type=int, default=13)
 
-parser.add_argument("--world_width", type=int, default=21)
+parser.add_argument("--maze_width", type=int, default=21)
 
-parser.add_argument("--world_nb_walls", type=int, default=15)
+parser.add_argument("--maze_nb_walls", type=int, default=15)
 
 ######################################################################
 
@@ -170,16 +172,23 @@ class TaskMaze(Task):
         s = s.reshape(s.size(0), -1, self.height, self.width)
         return (s[:, k] for k in range(s.size(1)))
 
-    def __init__(self, batch_size, height, width, nb_walls, device=torch.device("cpu")):
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        height,
+        width,
+        nb_walls,
+        device=torch.device("cpu"),
+    ):
         self.batch_size = batch_size
         self.height = height
         self.width = width
         self.device = device
 
-        nb = args.data_size if args.data_size > 0 else 250000
-
         mazes_train, paths_train = maze.create_maze_data(
-            (4 * nb) // 5,
+            nb_train_samples,
             height=height,
             width=width,
             nb_walls=nb_walls,
@@ -190,7 +199,7 @@ class TaskMaze(Task):
         self.nb_codes = self.train_input.max() + 1
 
         mazes_test, paths_test = maze.create_maze_data(
-            nb // 5,
+            nb_test_samples,
             height=height,
             width=width,
             nb_walls=nb_walls,
@@ -199,9 +208,11 @@ class TaskMaze(Task):
         mazes_test, paths_test = mazes_test.to(device), paths_test.to(device)
         self.test_input = self.map2seq(mazes_test, paths_test)
 
-    def batches(self, split="train"):
+    def batches(self, split="train", nb_to_use=-1):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
         for batch in tqdm.tqdm(
             input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
         ):
@@ -210,12 +221,13 @@ class TaskMaze(Task):
     def vocabulary_size(self):
         return self.nb_codes
 
-    def compute_error(self, model, split="train"):
+    def compute_error(self, model, split="train", nb_to_use=-1):
         nb_total, nb_correct = 0, 0
-        for input in task.batches(split):
+        for input in task.batches(split, nb_to_use):
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
             ar_mask[:, self.height * self.width :] = 1
+            result *= 1 - ar_mask
             masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
             mazes, paths = self.seq2map(result)
             nb_correct += maze.path_correctness(mazes, paths).long().sum()
@@ -224,26 +236,42 @@ class TaskMaze(Task):
         return nb_total, nb_correct
 
     def produce_results(self, n_epoch, model):
-        train_nb_total, train_nb_correct = self.compute_error(model, "train")
-        log_string(
-            f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
-        )
-
-        test_nb_total, test_nb_correct = self.compute_error(model, "test")
-        log_string(
-            f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
-        )
-
-        input = self.test_input[:32]
-        result = input.clone()
-        ar_mask = result.new_zeros(result.size())
+        with torch.autograd.no_grad():
+            t = model.training
+            model.eval()
+
+            train_nb_total, train_nb_correct = self.compute_error(
+                model, "train", nb_to_use=1000
+            )
+            log_string(
+                f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+            )
+
+            test_nb_total, test_nb_correct = self.compute_error(
+                model, "test", nb_to_use=1000
+            )
+            log_string(
+                f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+            )
+
+            input = self.test_input[:32]
+            result = input.clone()
+            ar_mask = result.new_zeros(result.size())
+            ar_mask[:, self.height * self.width :] = 1
+            result *= 1 - ar_mask
+            masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
 
-        ar_mask[:, self.height * self.width :] = 1
-        masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
+            mazes, paths = self.seq2map(input)
+            _, predicted_paths = self.seq2map(result)
+            maze.save_image(
+                f"result_{n_epoch:04d}.png",
+                mazes,
+                paths,
+                predicted_paths,
+                maze.path_correctness(mazes, predicted_paths),
+            )
 
-        mazes, paths = self.seq2map(input)
-        _, predicted_paths = self.seq2map(result)
-        maze.save_image(f"result_{n_epoch:04d}.png", mazes, paths, predicted_paths)
+            model.train(t)
 
 ######################################################################
 
@@ -252,10 +280,12 @@ log_string(f"device {device}")
 
 task = TaskMaze(
+    nb_train_samples=args.nb_train_samples,
+    nb_test_samples=args.nb_test_samples,
     batch_size=args.batch_size,
-    height=args.world_height,
-    width=args.world_width,
-    nb_walls=args.world_nb_walls,
+    height=args.maze_height,
+    width=args.maze_width,
+    nb_walls=args.maze_nb_walls,
    device=device,
 )
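Note on the diff: besides the argument renames and the train/test sample-count split, the one behavioral fix is the added `result *= 1 - ar_mask`, which zeroes the ground-truth path tokens in `result` before `masked_inplace_autoregression` regenerates them, so the decoded sequence cannot carry leftover answer tokens into the evaluation. Below is a minimal, self-contained sketch of that masking-then-regeneration pattern. The helper name `masked_inplace_autoregression_sketch`, the toy `ConstantModel`, and the decoding convention (model output at position t is the prediction for the token at position t) are illustrative assumptions, not the actual implementation in beaver.py.

import torch


def masked_inplace_autoregression_sketch(model, batch_size, input, ar_mask):
    # Regenerate, in place and left to right, the positions where
    # ar_mask == 1, conditioning on whatever is currently in `input`.
    # Assumed convention: output at position t predicts the token at t.
    for input_b, mask_b in zip(input.split(batch_size), ar_mask.split(batch_size)):
        for t in range(input_b.size(1)):
            logits = model(input_b)            # (N, T, vocab_size)
            tokens = logits[:, t].argmax(dim=-1)
            hit = mask_b[:, t] == 1
            input_b[hit, t] = tokens[hit]      # overwrite masked slots only


# Toy model that always predicts token 1, just to exercise the loop.
class ConstantModel(torch.nn.Module):
    def forward(self, x):
        logits = torch.zeros(x.size(0), x.size(1), 4)
        logits[:, :, 1] = 1.0
        return logits


seq = torch.randint(4, (8, 10))
ar_mask = torch.zeros_like(seq)
ar_mask[:, 5:] = 1       # second half plays the role of the path tokens
seq *= 1 - ar_mask       # the fix: blank the ground truth before decoding
masked_inplace_autoregression_sketch(ConstantModel(), 4, seq, ar_mask)
assert (seq[:, 5:] == 1).all()

Note that `Tensor.split` returns views of the original storage, so the in-place writes inside the loop update `seq` directly; that is what lets this pattern decode batch by batch without reassembling the results afterwards.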