X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=62a88918e6436fdc68e97f31a6851631ccfb91df;hb=f680fa1486b0a70c37f0951cedd7b5c56b5808bb;hp=0f3aaec3ff480ef8209e262baa61e150d23f4be5;hpb=68aa86a6645dfef3f919aad5732a1a09db77bfae;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 0f3aaec..62a8891 100755 --- a/tasks.py +++ b/tasks.py @@ -82,86 +82,6 @@ class PicoCLVR(Task): a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min() return z[:, a:b] - ###################### - # Not the cleanest part of the code - - # Extract the last image of each sequence, from the last - # included, and set to all the tokens from the beginning of - # that image to the end - def excise_last_image(self, input): - t_img, t_nul = self.token2id[""], self.token2id[""] - nb_img_tokens = self.height * self.width + 1 - - input = input.clone() - t = (input == t_img).long() - tail_masks = (t.cumsum(dim=1) == t.sum(dim=1, keepdim=True)).long() - i = (t * tail_masks).nonzero(as_tuple=True) - j = ( - i[0][:, None], - i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :], - ) - images = self.trim(input[j]) - input[j] = t_nul - loss_masks = 1 - tail_masks - input, loss_masks = self.trim((input, loss_masks)) - return input, loss_masks, images - - def add_true_image(self, input, images, loss_masks): - t_nul = self.token2id[""] - nb_img_tokens = self.height * self.width + 1 - input = F.pad(input, (0, nb_img_tokens), value=t_nul) - loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0) - t = (input == t_nul).long() - i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True) - j = ( - i[0][:, None], - i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :], - ) - input[j] = images - loss_masks[j] = 1 - input, loss_masks = self.trim((input, loss_masks)) - return input, loss_masks - - def add_generated_image(self, input, loss_masks, model, deterministic_synthesis): - t_img, t_nul = self.token2id[""], self.token2id[""] - nb_img_tokens = self.height * self.width + 1 - - input = F.pad(input, (0, nb_img_tokens), value=t_nul) - loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0) - t = (input == t_nul).long() - i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True) - input[i] = t_img - - j = ( - i[0][:, None], - i[1][:, None] - + 1 - + torch.arange(nb_img_tokens - 1, device=input.device)[None, :], - ) - ar_masks = input.new_zeros(input.size(), dtype=torch.int64) - ar_masks[j] = 1 - forbidden_tokens = ( - torch.arange(self.vocabulary_size(), device=input.device) == t_nul - ) - with torch.autograd.no_grad(): - t = model.training - model.eval() - masked_inplace_autoregression( - model, - self.batch_size, - input, - ar_masks, - deterministic_synthesis, - forbidden_tokens, - progress_bar_desc=None, - device=self.device, - ) - model.train(t) - - input, loss_masks = self.trim((input, loss_masks)) - - return input, loss_masks - ###################### def __init__( @@ -193,16 +113,6 @@ class PicoCLVR(Task): self.pruner_train = pruner_train self.pruner_eval = pruner_eval - param = { - "nb_train_samples": nb_train_samples, - "nb_test_samples": nb_test_samples, - "height": height, - "width": width, - "nb_colors": nb_colors, - "batch_size": batch_size, - "rng_state": list(torch.get_rng_state()), - } - if logger is not None: logger( f"generating {nb_train_samples+nb_test_samples} samples (can take some time)" @@ -225,6 +135,7 @@ class PicoCLVR(Task): tokens.sort() self.token2id = dict([(t, n) for n, t in enumerate(tokens)]) self.id2token = dict([(n, t) for n, t in enumerate(tokens)]) + self.t_img, self.t_nul = self.token2id[""], self.token2id[""] # Tokenize the train and test sets self.train_input = self.tensorize(self.train_descr) @@ -253,11 +164,20 @@ class PicoCLVR(Task): dynamic_ncols=True, desc=f"test-properties", ): - tape, loss_masks, _ = self.excise_last_image(input) - tape, loss_masks = self.add_generated_image( - tape, loss_masks, model, deterministic_synthesis + result = input.clone() + ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1) + result = (1 - ar_mask) * result + ar_mask * self.t_nul + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + progress_bar_desc=None, + device=self.device, ) - result_descr = self.detensorize(tape) + + result_descr = self.detensorize(result) np = picoclvr.nb_properties( result_descr, height=self.height, @@ -302,14 +222,23 @@ class PicoCLVR(Task): "red below yellow yellow below green green below blue red right yellow left green right blue left", "green bottom yellow bottom green left of blue yellow right of blue blue top", ]: - primer += [primer_descr] * nb_per_primer + primer += [primer_descr + " "] * nb_per_primer - tape = self.tensorize(primer) - loss_masks = 1 - (tape == self.token2id[""]).long() - tape, loss_masks = self.add_generated_image( - tape, loss_masks, model, deterministic_synthesis + result = self.tensorize(primer) + fill = result.new_full( + result.size()[:-1] + (self.height * self.width + 1,), self.t_nul + ) + result = torch.cat((result, fill), 1) + ar_mask = (result == self.t_nul).long() + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, ) - result_descr = self.detensorize(tape) + result_descr = self.detensorize(result) np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)