X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=db982cabf6264e695e4775c543a9015ae64aa77b;hb=ed60c541ca2225d69df96c2a382bb83c947bfe0e;hp=967923644118d08fc52a4bb76056810cb49ce99e;hpb=cf7fcbb7a946c4d1f4d29a28e0eb04940d3b0f76;p=picoclvr.git diff --git a/main.py b/main.py index 9679236..db982ca 100755 --- a/main.py +++ b/main.py @@ -102,7 +102,7 @@ parser.add_argument("--snake_width", type=int, default=8) parser.add_argument("--snake_nb_colors", type=int, default=5) -parser.add_argument("--snake_length", type=int, default=400) +parser.add_argument("--snake_length", type=int, default=200) ###################################################################### @@ -143,8 +143,8 @@ default_args = { "batch_size": 25, }, "snake": { - "nb_epochs": 25, - "batch_size": 20, + "nb_epochs": 5, + "batch_size": 25, }, } @@ -173,15 +173,27 @@ for n in vars(args): ###################################################################### +# ra_mask is boolean, with 1s on the values to generate + + def masked_inplace_autoregression( - model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu") + model, + batch_size, + input, + ar_mask, + forbidden_tokens=None, + progress_bar_desc="autoregression", + device=torch.device("cpu"), ): - for input, ar_mask in tqdm.tqdm( - zip(input.split(batch_size), ar_mask.split(batch_size)), - dynamic_ncols=True, - desc="autoregression", - total=input.size(0) // batch_size, - ): + batches = zip(input.split(batch_size), ar_mask.split(batch_size)) + if progress_bar_desc is not None: + tqdm.tqdm( + batches, + dynamic_ncols=True, + desc=progress_bar_desc, + total=input.size(0) // batch_size, + ) + for input, ar_mask in batches: i = (ar_mask.sum(0) > 0).nonzero() if i.min() > 0: model( @@ -317,6 +329,7 @@ class TaskPicoCLVR(Task): input, ar_masks, forbidden_tokens, + progress_bar_desc=None, device=self.device, ) model.train(t) @@ -689,7 +702,7 @@ class TaskSnake(Task): self.device = device self.prompt_length = prompt_length - self.train_input, self.train_prior_visits = snake.generate_sequences( + self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences( nb_train_samples, height, width, @@ -698,7 +711,7 @@ class TaskSnake(Task): prompt_length, self.device, ) - self.test_input, self.test_prior_visits = snake.generate_sequences( + self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences( nb_test_samples, height, width, @@ -975,9 +988,6 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): for input in task.batches(split="test"): input = input.to(device) - # input, loss_masks, true_images = task.excise_last_image(input) - # input, loss_masks = task.add_true_image(input, true_images, loss_masks) - output = model(mygpt.BracketedSequence(input)).x loss = F.cross_entropy(output.transpose(1, 2), input) acc_test_loss += loss.item() * input.size(0)