X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=218ff36e0f7e37f67de3e4e227457b61d6414a60;hb=26af4588b06ed463a4f9b9bcc4b527dd4c864d49;hp=4777a11676447c4137683ec988bd011a0ad69d81;hpb=e56873a0cb64555cbd47e44cdca0ce991765a5fc;p=mygptrnn.git diff --git a/tasks.py b/tasks.py index 4777a11..218ff36 100755 --- a/tasks.py +++ b/tasks.py @@ -250,7 +250,13 @@ class PicoCLVR(Task): # Make a list of strings from a tensor def detensorize(self, x): - return [" ".join([self.id2token[t.item()] for t in r]) for r in x] + def id2token(t): + try: + return self.id2token[t.item()] + except KeyError: + return "?" + + return [" ".join([id2token(t) for t in r]) for r in x] # trim all the tensors in the tuple z to remove as much token from # left and right in the first tensor. If z is a tuple, all its @@ -888,7 +894,10 @@ class Stack(Task): def compute_nb_correct(input): result = input.clone() stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) + ar_mask = (result != input).long() + result *= 1 - ar_mask + masked_inplace_autoregression( model, self.batch_size, @@ -923,10 +932,12 @@ class Stack(Task): stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) ar_mask = (result != input).long() - # for n in range(result.size(0)): - # logger( - # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" - # ) + for n in range(result.size(0)): + logger( + f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" + ) + + result *= 1 - ar_mask masked_inplace_autoregression( model, @@ -1448,7 +1459,13 @@ class Grid(Task): # Make a list of strings from a tensor def tensor2str(self, x): - return [" ".join([self.id2token[t.item()] for t in r]) for r in x] + def id2token(t): + try: + return self.id2token[t.item()] + except KeyError: + return "?" + + return [" ".join([id2token(t) for t in r]) for r in x] # trim all the tensors in the tuple z to remove as much token from # left and right in the first tensor. If z is a tuple, all its @@ -1473,6 +1490,8 @@ class Grid(Task): nb_test_samples, batch_size, size, + nb_shapes, + nb_colors, logger=None, device=torch.device("cpu"), ): @@ -1480,7 +1499,9 @@ class Grid(Task): self.device = device self.batch_size = batch_size - self.grid_factory = grid.GridFactory(size=size) + self.grid_factory = grid.GridFactory( + size=size, nb_shapes=nb_shapes, nb_colors=nb_colors + ) if logger is not None: logger(