X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=04b8f84577097a7ef3de515fa90f6a461429e7e5;hb=1406521aecaab783d2ea267b0f973ac17e091bf7;hp=affc8cdcfae81d694e46c1e7afdcb2cc2d0656b4;hpb=5aee50805cfad1dd49bbf30b30fe65b05e03de78;p=picoclvr.git

diff --git a/tasks.py b/tasks.py
index affc8cd..04b8f84 100755
--- a/tasks.py
+++ b/tasks.py
@@ -225,6 +225,10 @@ class PicoCLVR(Task):
         primer += [primer_descr + " "] * nb_per_primer

         result = self.tensorize(primer)
+        fill = result.new_full(
+            result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
+        )
+        result = torch.cat((result, fill), 1)
         ar_mask = (result == self.t_nul).long()
         masked_inplace_autoregression(
             model,
@@ -744,18 +748,21 @@ class Stack(Task):
         result = input.clone()
         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
         ar_mask = (result != input).long()
-        for n in range(result.size(0)):
-            logger(
-                f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
-            )
-        masked_inplace_autoregression(
-            model,
-            self.batch_size,
-            result,
-            ar_mask,
-            deterministic_synthesis,
-            device=self.device,
-        )
+
+        # for n in range(result.size(0)):
+        #     logger(
+        #         f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
+        #     )
+
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            result,
+            ar_mask,
+            deterministic_synthesis,
+            device=self.device,
+        )
+
         for n in range(result.size(0)):
             logger(
                 f"test_after {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
             )
@@ -772,6 +779,20 @@ import expr


 class Expr(Task):
+    def tensorize(self, sequences):
+        len_max = max([len(x) for x in sequences])
+        return torch.cat(
+            [
+                torch.tensor(
+                    [
+                        [self.char2id[c] for c in s + "#" * (len_max - len(s))]
+                        for s in sequences
+                    ]
+                )
+            ],
+            0,
+        ).to(self.device)
+
     def __init__(
         self,
         nb_train_samples,
@@ -788,51 +809,24 @@ class Expr(Task):
             nb_train_samples,
             nb_variables=nb_variables,
             length=sequence_length,
-            # length=2 * sequence_length,
-            # randomize_length=True,
         )
+
         test_sequences = expr.generate_sequences(
             nb_test_samples,
             nb_variables=nb_variables,
             length=sequence_length,
         )
-        self.char2id = dict(
-            [
-                (c, n)
-                for n, c in enumerate(
-                    set("#" + "".join(train_sequences + test_sequences))
-                )
-            ]
-        )
+
+        symbols = list(set("#" + "".join(train_sequences + test_sequences)))
+        symbols.sort()
+
+        self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
         self.id2char = dict([(n, c) for c, n in self.char2id.items()])

         self.filler, self.space = self.char2id["#"], self.char2id[" "]

-        len_max = max([len(x) for x in train_sequences])
-        self.train_input = torch.cat(
-            [
-                torch.tensor(
-                    [
-                        [self.char2id[c] for c in s + "#" * (len_max - len(s))]
-                        for s in train_sequences
-                    ]
-                )
-            ],
-            0,
-        ).to(device)
-
-        len_max = max([len(x) for x in test_sequences])
-        self.test_input = torch.cat(
-            [
-                torch.tensor(
-                    [
-                        [self.char2id[c] for c in s + "#" * (len_max - len(s))]
-                        for s in test_sequences
-                    ]
-                )
-            ],
-            0,
-        ).to(device)
+        self.train_input = self.tensorize(train_sequences)
+        self.test_input = self.tensorize(test_sequences)

         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

@@ -858,7 +852,13 @@ class Expr(Task):
         return "".join([self.id2char[k.item()] for k in s])

     def produce_results(
-        self, n_epoch, model, result_dir, logger, deterministic_synthesis
+        self,
+        n_epoch,
+        model,
+        result_dir,
+        logger,
+        deterministic_synthesis,
+        input_file=None,
     ):
         with torch.autograd.no_grad():
             t = model.training
@@ -910,7 +910,7 @@ class Expr(Task):
                 test_nb_correct,
                 test_nb_delta,
                 test_nb_missed,
-            ) = compute_nb_correct(self.test_input[:1000])
+            ) = compute_nb_correct(self.test_input[:10000])

             logger(
                 f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
@@ -927,25 +927,36 @@ class Expr(Task):
             ##############################################################
             # Log a few generated sequences

-            input = self.test_input[:10]
+            if input_file is None:
+                input = self.test_input[:10]
+            else:
+                with open(input_file, "r") as f:
+                    sequences = [e.strip() for e in f.readlines()]
+                    sequences = [s + " " + "#" * 50 for s in sequences]
+                    input = self.tensorize(sequences)
+
             result = input.clone()
-            ar_mask = (result == self.space).long().cumsum(dim=1).clamp(max=1)
+            s = (result == self.space).long()
+            ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
             result = (1 - ar_mask) * result + ar_mask * self.filler
+
             for n in range(result.size(0)):
                 logger(f"test_before {self.seq2str(result[n])}")
-            masked_inplace_autoregression(
-                model,
-                self.batch_size,
-                result,
-                ar_mask,
-                deterministic_synthesis,
-                device=self.device,
-            )
+
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                device=self.device,
+            )
+
             correct = (1 - ar_mask) * self.space + ar_mask * input
             for n in range(result.size(0)):
                 comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
                 logger(f"test_after {self.seq2str(result[n])} {comment}")
-                logger(f"correct {self.seq2str(correct[n])}")
+                logger(f"truth {self.seq2str(correct[n])}")

             ##############################################################

             model.train(t)
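
Notes on the changes above.

The first hunk (PicoCLVR.produce_results) appends self.height * self.width + 1 nul tokens to each tensorized primer, so that ar_mask = (result == self.t_nul).long() selects exactly that padding for the autoregressive fill-in. A minimal sketch of the pattern, with toy sizes and an assumed nul id of 0 rather than the repo's actual vocabulary:

    import torch

    # Toy stand-ins for self.t_nul, self.height, self.width.
    t_nul, height, width = 0, 2, 3

    result = torch.randint(1, 10, (4, 5))  # tensorized primers, ids > 0 here
    fill = result.new_full(result.size()[:-1] + (height * width + 1,), t_nul)
    result = torch.cat((result, fill), 1)  # shape (4, 5 + height * width + 1)
    ar_mask = (result == t_nul).long()     # 1 exactly on the appended padding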
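
In Expr.__init__, the symbol set is now sorted before ids are assigned. Iteration order over a Python set of strings depends on the per-process hash seed, so the old construction could assign different ids to the same characters from one run to the next; sorting makes char2id deterministic, which matters as soon as a checkpoint is reloaded in a fresh process. A sketch of the construction with a made-up corpus (the repo builds the sequences with expr.generate_sequences):

    sequences = ["a=1", "b=a+2"]  # made-up examples

    symbols = list(set("#" + "".join(sequences)))
    symbols.sort()  # stable order, hence stable token ids across runs

    char2id = dict([(c, n) for n, c in enumerate(symbols)])
    id2char = dict([(n, c) for c, n in char2id.items()])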
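
The last hunk changes which positions Expr regenerates. The old mask, (result == self.space).long().cumsum(dim=1).clamp(max=1), switches on at the first space and stays on, so the space itself was overwritten with the filler and had to be re-predicted. Subtracting s before clamping leaves that first space in the prompt and masks only what follows it (later spaces, where the cumulative sum is already at least 2, stay masked). A small sketch, assuming a space id of 0:

    import torch

    space = 0
    result = torch.tensor([[7, 8, 0, 5, 6]])  # one row; 0 marks the space

    s = (result == space).long()
    old_mask = s.cumsum(dim=1).clamp(max=1)               # [[0, 0, 1, 1, 1]]
    new_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)  # [[0, 0, 0, 1, 1]]

The same hunk adds the input_file argument to produce_results: when given, prompts are read one per line and padded with " " + "#" * 50 before tensorization, instead of taking the first ten test sequences.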