X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=acecfdd311048fdbc1aedc49becae0b703783fbc;hb=b003cc9f89b7c3356f7d1e6c0c10b3dea249ef96;hp=3db87df303a3cb55e8f2755f7b953485596da00f;hpb=d363acfa35249faaa1fc6574e50c1c59da277141;p=picoclvr.git diff --git a/main.py b/main.py index 3db87df..acecfdd 100755 --- a/main.py +++ b/main.py @@ -92,6 +92,17 @@ parser.add_argument("--maze_width", type=int, default=21) parser.add_argument("--maze_nb_walls", type=int, default=15) +############################## +# Snake options + +parser.add_argument("--snake_height", type=int, default=6) + +parser.add_argument("--snake_width", type=int, default=8) + +parser.add_argument("--snake_nb_colors", type=int, default=3) + +parser.add_argument("--snake_length", type=int, default=400) + ###################################################################### args = parser.parse_args() @@ -488,7 +499,7 @@ class TaskMNIST(Task): masked_inplace_autoregression( model, self.batch_size, results, ar_mask, device=self.device ) - image_name = os.path.join(args.result_dir, f"result_mnist_{n_epoch:04d}.png") + image_name = os.path.join(args.result_dir, f"mnist_result_{n_epoch:04d}.png") torchvision.utils.save_image( 1 - results.reshape(-1, 1, 28, 28) / 255.0, image_name, @@ -608,7 +619,7 @@ class TaskMaze(Task): mazes, paths = self.seq2map(input) _, predicted_paths = self.seq2map(result) - filename = os.path.join(args.result_dir, f"result_{n_epoch:04d}.png") + filename = os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.png") maze.save_image( filename, mazes=mazes, @@ -623,6 +634,72 @@ class TaskMaze(Task): ###################################################################### + +def generate_snake_sequences( + nb, height, width, nb_colors, length, device=torch.device("cpu") +): + worlds = torch.randint(nb_colors, (nb, height, width), device=device) + # nb x 2 + snake_position = torch.cat( + ( + torch.randint(height, (nb, 1), device=device), + torch.randint(width, (nb, 1), device=device), + ), + 1, + ) + snake_direction = torch.randint(4, (nb,), device=device) + sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64) + i = torch.arange(nb, device=device) # [:,None] + + for l in range(length): + # nb x 3 + snake_next_direction = torch.cat( + ( + (snake_direction[:, None] - 1) % 4, + snake_direction[:, None], + (snake_direction[:, None] + 1) % 4, + ), + 1, + ) + + # nb x 3 + vh = (snake_next_direction + 1) % 2 * (snake_next_direction - 1) + vw = snake_next_direction % 2 * (snake_next_direction - 2) + + # nb x 3 x 2 + snake_next_speed = torch.cat((vh[:, :, None], vw[:, :, None]), 2) + snake_next_position = snake_position[:, None, :] + snake_next_speed + + # nb x 3 + val = torch.logical_and( + torch.logical_and( + snake_next_position[:, :, 0] >= 0, snake_next_position[:, :, 0] < height + ), + torch.logical_and( + snake_next_position[:, :, 1] >= 0, snake_next_position[:, :, 1] < width + ), + ).float() + val = ( + torch.rand_like(val) * val * torch.tensor([[1.0, 4.0, 1.0]], device=device) + ) + + # nb + j = val.argmax(1) + snake_direction = snake_next_direction[i, j] + + sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4 + sequences[:, 2 * l + 1] = snake_direction + + # nb x 2 + snake_position = snake_next_position[i, j] + + return sequences, worlds + + +# generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20) +# exit(0) + + class TaskSnake(Task): def __init__( self, @@ -631,7 +708,8 @@ class TaskSnake(Task): batch_size, height, width, - nb_walls, + nb_colors, + length, device=torch.device("cpu"), ): self.batch_size = batch_size @@ -639,10 +717,14 @@ class TaskSnake(Task): self.width = width self.device = device - # self.train_input = - # self.test_input = + self.train_input, self.train_worlds = generate_snake_sequences( + nb_train_samples, height, width, nb_colors, length, self.device + ) + self.test_input, self.test_worlds = generate_snake_sequences( + nb_test_samples, height, width, nb_colors, length, self.device + ) - self.nb_codes = max(self.train_input.max(), self.train_input.max()) + 1 + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} @@ -656,6 +738,47 @@ class TaskSnake(Task): ): yield batch + def vocabulary_size(self): + return self.nb_codes + + def produce_results(self, n_epoch, model): + with torch.autograd.no_grad(): + t = model.training + model.eval() + + def compute_nb_correct(input): + result = input.clone() + i = torch.arange(result.size(1), device=result.device) + ar_mask = torch.logical_and(i >= i.size(0) // 2, i % 2 == 0)[ + None, : + ].long() + result *= 1 - ar_mask + masked_inplace_autoregression( + model, self.batch_size, result, ar_mask, device=self.device + ) + + nb_total = ar_mask.sum() * input.size(0) + nb_correct = ((result == input).long() * ar_mask).sum() + + # nb_total = result.size(0) + # nb_correct = ((result - input).abs().sum(1) == 0).sum() + + return nb_total, nb_correct + + train_nb_total, train_nb_correct = compute_nb_correct(self.train_input) + + log_string( + f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" + ) + + test_nb_total, test_nb_correct = compute_nb_correct(self.test_input) + + log_string( + f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) + + model.train(t) + ###################################################################### @@ -708,6 +831,18 @@ elif args.task == "maze": device=device, ) +elif args.task == "snake": + task = TaskSnake( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + height=args.snake_height, + width=args.snake_width, + nb_colors=args.snake_nb_colors, + length=args.snake_length, + device=device, + ) + else: raise ValueError(f"Unknown task {args.task}")