X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=463d94ca1ae4dc43a054c875eba2fe686f0686bd;hb=233f57347c9560aec2f3cbaf001a8efa56a0243b;hp=0f3aaec3ff480ef8209e262baa61e150d23f4be5;hpb=68aa86a6645dfef3f919aad5732a1a09db77bfae;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 0f3aaec..463d94c 100755 --- a/tasks.py +++ b/tasks.py @@ -82,86 +82,6 @@ class PicoCLVR(Task): a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min() return z[:, a:b] - ###################### - # Not the cleanest part of the code - - # Extract the last image of each sequence, from the last - # included, and set to all the tokens from the beginning of - # that image to the end - def excise_last_image(self, input): - t_img, t_nul = self.token2id[""], self.token2id[""] - nb_img_tokens = self.height * self.width + 1 - - input = input.clone() - t = (input == t_img).long() - tail_masks = (t.cumsum(dim=1) == t.sum(dim=1, keepdim=True)).long() - i = (t * tail_masks).nonzero(as_tuple=True) - j = ( - i[0][:, None], - i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :], - ) - images = self.trim(input[j]) - input[j] = t_nul - loss_masks = 1 - tail_masks - input, loss_masks = self.trim((input, loss_masks)) - return input, loss_masks, images - - def add_true_image(self, input, images, loss_masks): - t_nul = self.token2id[""] - nb_img_tokens = self.height * self.width + 1 - input = F.pad(input, (0, nb_img_tokens), value=t_nul) - loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0) - t = (input == t_nul).long() - i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True) - j = ( - i[0][:, None], - i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :], - ) - input[j] = images - loss_masks[j] = 1 - input, loss_masks = self.trim((input, loss_masks)) - return input, loss_masks - - def add_generated_image(self, input, loss_masks, model, deterministic_synthesis): - t_img, t_nul = self.token2id[""], self.token2id[""] - nb_img_tokens = self.height * self.width + 1 - - input = F.pad(input, (0, nb_img_tokens), value=t_nul) - loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0) - t = (input == t_nul).long() - i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True) - input[i] = t_img - - j = ( - i[0][:, None], - i[1][:, None] - + 1 - + torch.arange(nb_img_tokens - 1, device=input.device)[None, :], - ) - ar_masks = input.new_zeros(input.size(), dtype=torch.int64) - ar_masks[j] = 1 - forbidden_tokens = ( - torch.arange(self.vocabulary_size(), device=input.device) == t_nul - ) - with torch.autograd.no_grad(): - t = model.training - model.eval() - masked_inplace_autoregression( - model, - self.batch_size, - input, - ar_masks, - deterministic_synthesis, - forbidden_tokens, - progress_bar_desc=None, - device=self.device, - ) - model.train(t) - - input, loss_masks = self.trim((input, loss_masks)) - - return input, loss_masks - ###################### def __init__( @@ -193,16 +113,6 @@ class PicoCLVR(Task): self.pruner_train = pruner_train self.pruner_eval = pruner_eval - param = { - "nb_train_samples": nb_train_samples, - "nb_test_samples": nb_test_samples, - "height": height, - "width": width, - "nb_colors": nb_colors, - "batch_size": batch_size, - "rng_state": list(torch.get_rng_state()), - } - if logger is not None: logger( f"generating {nb_train_samples+nb_test_samples} samples (can take some time)" @@ -225,6 +135,7 @@ class PicoCLVR(Task): tokens.sort() self.token2id = dict([(t, n) for n, t in enumerate(tokens)]) self.id2token = dict([(n, t) for n, t in enumerate(tokens)]) + self.t_img, self.t_nul = self.token2id[""], self.token2id[""] # Tokenize the train and test sets self.train_input = self.tensorize(self.train_descr) @@ -253,11 +164,20 @@ class PicoCLVR(Task): dynamic_ncols=True, desc=f"test-properties", ): - tape, loss_masks, _ = self.excise_last_image(input) - tape, loss_masks = self.add_generated_image( - tape, loss_masks, model, deterministic_synthesis + result = input.clone() + ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1) + result = (1 - ar_mask) * result + ar_mask * self.t_nul + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + progress_bar_desc=None, + device=self.device, ) - result_descr = self.detensorize(tape) + + result_descr = self.detensorize(result) np = picoclvr.nb_properties( result_descr, height=self.height, @@ -302,14 +222,23 @@ class PicoCLVR(Task): "red below yellow yellow below green green below blue red right yellow left green right blue left", "green bottom yellow bottom green left of blue yellow right of blue blue top", ]: - primer += [primer_descr] * nb_per_primer + primer += [primer_descr + " "] * nb_per_primer - tape = self.tensorize(primer) - loss_masks = 1 - (tape == self.token2id[""]).long() - tape, loss_masks = self.add_generated_image( - tape, loss_masks, model, deterministic_synthesis + result = self.tensorize(primer) + fill = result.new_full( + result.size()[:-1] + (self.height * self.width + 1,), self.t_nul ) - result_descr = self.detensorize(tape) + result = torch.cat((result, fill), 1) + ar_mask = (result == self.t_nul).long() + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + result_descr = self.detensorize(result) np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width) @@ -819,18 +748,21 @@ class Stack(Task): result = input.clone() stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) ar_mask = (result != input).long() - for n in range(result.size(0)): - logger( - f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" - ) - masked_inplace_autoregression( - model, - self.batch_size, - result, - ar_mask, - deterministic_synthesis, - device=self.device, - ) + + # for n in range(result.size(0)): + # logger( + # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" + # ) + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + for n in range(result.size(0)): logger( f"test_after {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" @@ -847,6 +779,20 @@ import expr class Expr(Task): + def tensorize(self, sequences): + len_max = max([len(x) for x in sequences]) + return torch.cat( + [ + torch.tensor( + [ + [self.char2id[c] for c in s + "#" * (len_max - len(s))] + for s in sequences + ] + ) + ], + 0, + ).to(self.device) + def __init__( self, nb_train_samples, @@ -871,43 +817,17 @@ class Expr(Task): nb_variables=nb_variables, length=sequence_length, ) - self.char2id = dict( - [ - (c, n) - for n, c in enumerate( - set("#" + "".join(train_sequences + test_sequences)) - ) - ] - ) + + symbols = list(set("#" + "".join(train_sequences + test_sequences))) + symbols.sort() + + self.char2id = dict([(c, n) for n, c in enumerate(symbols)]) self.id2char = dict([(n, c) for c, n in self.char2id.items()]) self.filler, self.space = self.char2id["#"], self.char2id[" "] - len_max = max([len(x) for x in train_sequences]) - self.train_input = torch.cat( - [ - torch.tensor( - [ - [self.char2id[c] for c in s + "#" * (len_max - len(s))] - for s in train_sequences - ] - ) - ], - 0, - ).to(device) - - len_max = max([len(x) for x in test_sequences]) - self.test_input = torch.cat( - [ - torch.tensor( - [ - [self.char2id[c] for c in s + "#" * (len_max - len(s))] - for s in test_sequences - ] - ) - ], - 0, - ).to(device) + self.train_input = self.tensorize(train_sequences) + self.test_input = self.tensorize(test_sequences) self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 @@ -933,7 +853,13 @@ class Expr(Task): return "".join([self.id2char[k.item()] for k in s]) def produce_results( - self, n_epoch, model, result_dir, logger, deterministic_synthesis + self, + n_epoch, + model, + result_dir, + logger, + deterministic_synthesis, + input_file=None, ): with torch.autograd.no_grad(): t = model.training @@ -1002,25 +928,36 @@ class Expr(Task): ############################################################## # Log a few generated sequences - input = self.test_input[:10] + if input_file is None: + input = self.test_input[:10] + else: + with open(input_file, "r") as f: + sequences = [e.strip() for e in f.readlines()] + sequences = [s + " " + "#" * 50 for s in sequences] + input = self.tensorize(sequences) + result = input.clone() - ar_mask = (result == self.space).long().cumsum(dim=1).clamp(max=1) + s = (result == self.space).long() + ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1) result = (1 - ar_mask) * result + ar_mask * self.filler + for n in range(result.size(0)): logger(f"test_before {self.seq2str(result[n])}") - masked_inplace_autoregression( - model, - self.batch_size, - result, - ar_mask, - deterministic_synthesis, - device=self.device, - ) + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + correct = (1 - ar_mask) * self.space + ar_mask * input for n in range(result.size(0)): comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else "" logger(f"test_after {self.seq2str(result[n])} {comment}") - logger(f"correct {self.seq2str(correct[n])}") + logger(f"truth {self.seq2str(correct[n])}") ############################################################## model.train(t)