a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
return z[:, a:b]
- ######################
- # Not the cleanest part of the code
-
- # Extract the last image of each sequence, from the last <img>
- # included, and set to <nul> all the tokens from the beginning of
- # that image to the end
- def excise_last_image(self, input):
- t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
- nb_img_tokens = self.height * self.width + 1
-
- input = input.clone()
- t = (input == t_img).long()
- tail_masks = (t.cumsum(dim=1) == t.sum(dim=1, keepdim=True)).long()
- i = (t * tail_masks).nonzero(as_tuple=True)
- j = (
- i[0][:, None],
- i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
- )
- images = self.trim(input[j])
- input[j] = t_nul
- loss_masks = 1 - tail_masks
- input, loss_masks = self.trim((input, loss_masks))
- return input, loss_masks, images
-
- def add_true_image(self, input, images, loss_masks):
- t_nul = self.token2id["<nul>"]
- nb_img_tokens = self.height * self.width + 1
- input = F.pad(input, (0, nb_img_tokens), value=t_nul)
- loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
- t = (input == t_nul).long()
- i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
- j = (
- i[0][:, None],
- i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
- )
- input[j] = images
- loss_masks[j] = 1
- input, loss_masks = self.trim((input, loss_masks))
- return input, loss_masks
-
- def add_generated_image(self, input, loss_masks, model, deterministic_synthesis):
- t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
- nb_img_tokens = self.height * self.width + 1
-
- input = F.pad(input, (0, nb_img_tokens), value=t_nul)
- loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
- t = (input == t_nul).long()
- i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
- input[i] = t_img
-
- j = (
- i[0][:, None],
- i[1][:, None]
- + 1
- + torch.arange(nb_img_tokens - 1, device=input.device)[None, :],
- )
- ar_masks = input.new_zeros(input.size(), dtype=torch.int64)
- ar_masks[j] = 1
- forbidden_tokens = (
- torch.arange(self.vocabulary_size(), device=input.device) == t_nul
- )
- with torch.autograd.no_grad():
- t = model.training
- model.eval()
- masked_inplace_autoregression(
- model,
- self.batch_size,
- input,
- ar_masks,
- deterministic_synthesis,
- forbidden_tokens,
- progress_bar_desc=None,
- device=self.device,
- )
- model.train(t)
-
- input, loss_masks = self.trim((input, loss_masks))
-
- return input, loss_masks
-
######################
def __init__(
self.pruner_train = pruner_train
self.pruner_eval = pruner_eval
- param = {
- "nb_train_samples": nb_train_samples,
- "nb_test_samples": nb_test_samples,
- "height": height,
- "width": width,
- "nb_colors": nb_colors,
- "batch_size": batch_size,
- "rng_state": list(torch.get_rng_state()),
- }
-
if logger is not None:
logger(
f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
tokens.sort()
self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
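+ # Keep the ids of the special <img> and <nul> tokens at hand; the
+ # generation code below builds its autoregression masks from them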
+ self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
# Tokenize the train and test sets
self.train_input = self.tensorize(self.train_descr)
dynamic_ncols=True,
desc=f"test-properties",
):
- tape, loss_masks, _ = self.excise_last_image(input)
- tape, loss_masks = self.add_generated_image(
- tape, loss_masks, model, deterministic_synthesis
+ result = input.clone()
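+ # ar_mask flags every position from the first <img> token onward;
+ # those positions are blanked to <nul> before being re-generated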
+ ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
+ result = (1 - ar_mask) * result + ar_mask * self.t_nul
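+ # Re-generate the masked image tokens in place with the model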
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ progress_bar_desc=None,
+ device=self.device,
)
- result_descr = self.detensorize(tape)
+
+ result_descr = self.detensorize(result)
np = picoclvr.nb_properties(
result_descr,
height=self.height,
"red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
"green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
]:
- primer += [primer_descr] * nb_per_primer
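+ # End each primer with the <img> marker, presumably so that only the image tokens remain to be generated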
+ primer += [primer_descr + " <img>"] * nb_per_primer
- tape = self.tensorize(primer)
- loss_masks = 1 - (tape == self.token2id["<nul>"]).long()
- tape, loss_masks = self.add_generated_image(
- tape, loss_masks, model, deterministic_synthesis
+ result = self.tensorize(primer)
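+ # The <nul> padding left by tensorize after each primer is what the model has to fill in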
+ ar_mask = (result == self.t_nul).long()
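+ # Generate the image tokens at the masked positions in place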
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ device=self.device,
)
- result_descr = self.detensorize(tape)
+ result_descr = self.detensorize(result)
np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)