X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=e723866e080924a6dc4818e39eac84dadffa0d0c;hb=fc3bfa393cd7c527c5342441da8cce24a71a63f2;hp=6e8ebff6ad9bd449e1e8f5b144ac6eab4016b328;hpb=a35b7a5ebee0b58fc76b64c13d7550eb71bc4567;p=picoclvr.git diff --git a/main.py b/main.py index 6e8ebff..e723866 100755 --- a/main.py +++ b/main.py @@ -8,7 +8,7 @@ # torch.backends.cuda.matmul.allow_tf23 # torch.autocast(torch.bfloat16) -import math, sys, argparse, time, tqdm, itertools, os +import math, sys, argparse, time, tqdm, os import torch, torchvision from torch import nn @@ -27,7 +27,8 @@ else: ###################################################################### parser = argparse.ArgumentParser( - description="An implementation of GPT with cache to solve a toy geometric reasoning task." + description="An implementation of GPT with cache.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("--task", type=str, default="picoclvr") @@ -40,7 +41,7 @@ parser.add_argument("--seed", type=int, default=0) parser.add_argument("--nb_epochs", type=int, default=25) -parser.add_argument("--batch_size", type=int, default=25) +parser.add_argument("--batch_size", type=int, default=None) parser.add_argument("--nb_train_samples", type=int, default=250000) @@ -101,7 +102,7 @@ parser.add_argument("--snake_width", type=int, default=8) parser.add_argument("--snake_nb_colors", type=int, default=3) -parser.add_argument("--snake_length", type=int, default=100) +parser.add_argument("--snake_length", type=int, default=400) ###################################################################### @@ -128,6 +129,28 @@ if args.seed >= 0: ###################################################################### +default_args = { + "picoclvr": { + "batch_size": 25, + }, + "mnist": { + "batch_size": 10, + }, + "maze": { + "batch_size": 25, + }, + "snake": { + "batch_size": 20, + }, +} + +if args.task in default_args: + for k, v in default_args[args.task].items(): + if getattr(args, k) is None: + setattr(args, k, v) + +###################################################################### + def log_string(s): t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) @@ -499,7 +522,7 @@ class TaskMNIST(Task): masked_inplace_autoregression( model, self.batch_size, results, ar_mask, device=self.device ) - image_name = os.path.join(args.result_dir, f"result_mnist_{n_epoch:04d}.png") + image_name = os.path.join(args.result_dir, f"mnist_result_{n_epoch:04d}.png") torchvision.utils.save_image( 1 - results.reshape(-1, 1, 28, 28) / 255.0, image_name, @@ -619,7 +642,7 @@ class TaskMaze(Task): mazes, paths = self.seq2map(input) _, predicted_paths = self.seq2map(result) - filename = os.path.join(args.result_dir, f"result_{n_epoch:04d}.png") + filename = os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.png") maze.save_image( filename, mazes=mazes, @@ -639,6 +662,8 @@ def generate_snake_sequences( nb, height, width, nb_colors, length, device=torch.device("cpu") ): worlds = torch.randint(nb_colors, (nb, height, width), device=device) + nb_prior_visits = torch.zeros(nb, height, width, device=device) + # nb x 2 snake_position = torch.cat( ( @@ -649,7 +674,10 @@ def generate_snake_sequences( ) snake_direction = torch.randint(4, (nb,), device=device) sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64) - count = torch.arange(nb, device=device) # [:,None] + sequences_prior_visits = torch.zeros( + nb, 2 * length, device=device, dtype=torch.int64 + ) + i = torch.arange(nb, device=device) # [:,None] for l in range(length): # nb x 3 @@ -680,23 +708,27 @@ def generate_snake_sequences( ), ).float() val = ( - torch.rand_like(val) * val * torch.tensor([[1.0, 4.0, 1.0]], device=device) + # The multiplicative factors bias toward moving forward + torch.rand_like(val) + * val + * torch.tensor([[1.0, 2.0, 1.0]], device=device) ) # nb - i = torch.arange(val.size(0), device=device) j = val.argmax(1) snake_direction = snake_next_direction[i, j] - sequences[:, 2 * l] = worlds[count, snake_position[:, 0], snake_position[:, 1]] + sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4 + sequences_prior_visits[:, 2 * l] = nb_prior_visits[ + i, snake_position[:, 0], snake_position[:, 1] + ] + nb_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1 sequences[:, 2 * l + 1] = snake_direction # nb x 2 snake_position = snake_next_position[i, j] - return sequences, worlds - - # print(snake_position) + return sequences, sequences_prior_visits # generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20) @@ -720,10 +752,10 @@ class TaskSnake(Task): self.width = width self.device = device - self.train_input, self.train_worlds = generate_snake_sequences( + self.train_input, self.train_prior_visits = generate_snake_sequences( nb_train_samples, height, width, nb_colors, length, self.device ) - self.test_input, self.test_worlds = generate_snake_sequences( + self.test_input, self.test_prior_visits = generate_snake_sequences( nb_test_samples, height, width, nb_colors, length, self.device ) @@ -744,6 +776,51 @@ class TaskSnake(Task): def vocabulary_size(self): return self.nb_codes + def produce_results(self, n_epoch, model): + with torch.autograd.no_grad(): + t = model.training + model.eval() + + def compute_nb_correct(input, prior_visits): + result = input.clone() + i = torch.arange(result.size(1), device=result.device)[None, :] + ar_mask = torch.logical_and(i >= i.size(0) // 2, i % 2 == 0).long() + result *= 1 - ar_mask + masked_inplace_autoregression( + model, self.batch_size, result, ar_mask, device=self.device + ) + + nb_total = ( + (prior_visits > 0) * ar_mask + ).sum() + + nb_correct = ( + (result == input).long() * (prior_visits > 0) * ar_mask + ).sum() + + # nb_total = result.size(0) + # nb_correct = ((result - input).abs().sum(1) == 0).sum() + + return nb_total, nb_correct + + train_nb_total, train_nb_correct = compute_nb_correct( + self.train_input, self.train_prior_visits + ) + + log_string( + f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" + ) + + test_nb_total, test_nb_correct = compute_nb_correct( + self.test_input, self.test_prior_visits + ) + + log_string( + f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) + + model.train(t) + ######################################################################