X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=218ff36e0f7e37f67de3e4e227457b61d6414a60;hb=26af4588b06ed463a4f9b9bcc4b527dd4c864d49;hp=58638ed95ae343842ab810727b4aefd9fe0daabe;hpb=4395f9a90218819997c706de9505cda1c86ad507;p=mygptrnn.git diff --git a/tasks.py b/tasks.py index 58638ed..218ff36 100755 --- a/tasks.py +++ b/tasks.py @@ -58,7 +58,7 @@ def masked_inplace_autoregression( class Task: - def batches(self, split="train"): + def batches(self, split="train", desc=None): pass def vocabulary_size(self): @@ -250,7 +250,13 @@ class PicoCLVR(Task): # Make a list of strings from a tensor def detensorize(self, x): - return [" ".join([self.id2token[t.item()] for t in r]) for r in x] + def id2token(t): + try: + return self.id2token[t.item()] + except KeyError: + return "?" + + return [" ".join([id2token(t) for t in r]) for r in x] # trim all the tensors in the tuple z to remove as much token from # left and right in the first tensor. If z is a tuple, all its @@ -328,7 +334,7 @@ class PicoCLVR(Task): self.train_input = self.tensorize(self.train_descr) self.test_input = self.tensorize(self.test_descr) - def batches(self, split="train"): + def batches(self, split="train", desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input for batch in tqdm.tqdm( @@ -888,7 +894,10 @@ class Stack(Task): def compute_nb_correct(input): result = input.clone() stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) + ar_mask = (result != input).long() + result *= 1 - ar_mask + masked_inplace_autoregression( model, self.batch_size, @@ -923,10 +932,12 @@ class Stack(Task): stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) ar_mask = (result != input).long() - # for n in range(result.size(0)): - # logger( - # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" - # ) + for n in range(result.size(0)): + logger( + f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" + ) + + result *= 1 - ar_mask masked_inplace_autoregression( model, @@ -1448,7 +1459,13 @@ class Grid(Task): # Make a list of strings from a tensor def tensor2str(self, x): - return [" ".join([self.id2token[t.item()] for t in r]) for r in x] + def id2token(t): + try: + return self.id2token[t.item()] + except KeyError: + return "?" + + return [" ".join([id2token(t) for t in r]) for r in x] # trim all the tensors in the tuple z to remove as much token from # left and right in the first tensor. If z is a tuple, all its @@ -1473,6 +1490,8 @@ class Grid(Task): nb_test_samples, batch_size, size, + nb_shapes, + nb_colors, logger=None, device=torch.device("cpu"), ): @@ -1480,7 +1499,9 @@ class Grid(Task): self.device = device self.batch_size = batch_size - self.grid_factory = grid.GridFactory(size=size) + self.grid_factory = grid.GridFactory( + size=size, nb_shapes=nb_shapes, nb_colors=nb_colors + ) if logger is not None: logger( @@ -1515,11 +1536,13 @@ class Grid(Task): self.train_input = self.str2tensor(self.train_descr) self.test_input = self.str2tensor(self.test_descr) - def batches(self, split="train"): + def batches(self, split="train", desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input + if desc is None: + desc = f"epoch-{split}" for batch in tqdm.tqdm( - input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" + input.split(self.batch_size), dynamic_ncols=True, desc=desc ): yield self.trim(batch) @@ -1618,11 +1641,13 @@ class QMLP(Task): self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 - def batches(self, split="train"): + def batches(self, split="train", desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input + if desc is None: + desc = f"epoch-{split}" for batch in tqdm.tqdm( - input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" + input.split(self.batch_size), dynamic_ncols=True, desc=desc ): yield batch