X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=57c68018eca2ea2abba35478af62a15b91a5ab10;hb=HEAD;hp=58638ed95ae343842ab810727b4aefd9fe0daabe;hpb=4395f9a90218819997c706de9505cda1c86ad507;p=mygptrnn.git

diff --git a/tasks.py b/tasks.py
index 58638ed..57c6801 100755
--- a/tasks.py
+++ b/tasks.py
@@ -58,7 +58,7 @@ def masked_inplace_autoregression(
 
 
 class Task:
-    def batches(self, split="train"):
+    def batches(self, split="train", desc=None):
         pass
 
     def vocabulary_size(self):
@@ -106,7 +106,7 @@ class SandBox(Task):
             device
         ), self.test_ar_mask.to(device)
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
         # A bit of paranoia never hurts
         assert self.nb_codes <= max_nb_codes
@@ -250,7 +250,13 @@ class PicoCLVR(Task):
 
     # Make a list of strings from a tensor
     def detensorize(self, x):
-        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
+        def id2token(t):
+            try:
+                return self.id2token[t.item()]
+            except KeyError:
+                return "?"
+
+        return [" ".join([id2token(t) for t in r]) for r in x]
 
     # trim all the tensors in the tuple z to remove as much token from
     # left and right in the first tensor. If z is a tuple, all its
@@ -328,7 +334,7 @@ class PicoCLVR(Task):
         self.train_input = self.tensorize(self.train_descr)
         self.test_input = self.tensorize(self.test_descr)
 
-    def batches(self, split="train"):
+    def batches(self, split="train", desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
         for batch in tqdm.tqdm(
@@ -573,7 +579,7 @@ class Maze(Task):
         )
         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -750,7 +756,7 @@ class Snake(Task):
             self.device,
         )
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -865,7 +871,7 @@ class Stack(Task):
         counts = F.one_hot(counts).sum(0)
         logger(f"test_pop_stack_counts {counts}")
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -888,7 +894,10 @@ class Stack(Task):
         def compute_nb_correct(input):
             result = input.clone()
             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
+
             ar_mask = (result != input).long()
+            result *= 1 - ar_mask
+
             masked_inplace_autoregression(
                 model,
                 self.batch_size,
@@ -923,10 +932,12 @@ class Stack(Task):
         stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
         ar_mask = (result != input).long()
 
-        # for n in range(result.size(0)):
-        #     logger(
-        #         f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
-        #     )
+        for n in range(result.size(0)):
+            logger(
+                f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
+            )
+
+        result *= 1 - ar_mask
 
         masked_inplace_autoregression(
             model,
@@ -1067,7 +1078,7 @@ class RPL(Task):
                 s = " ".join(seq)
                 logger(f"example_seq {s}")
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -1297,7 +1308,7 @@ class Expr(Task):
         self.train_input = self.tensorize(train_sequences)
         self.test_input = self.tensorize(test_sequences)
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -1448,7 +1459,13 @@ class Grid(Task):
 
     # Make a list of strings from a tensor
    def tensor2str(self, x):
-        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
+        def id2token(t):
+            try:
+                return self.id2token[t.item()]
+            except KeyError:
+                return "?"
+
+        return [" ".join([id2token(t) for t in r]) for r in x]
 
     # trim all the tensors in the tuple z to remove as much token from
     # left and right in the first tensor. If z is a tuple, all its
@@ -1473,6 +1490,8 @@ class Grid(Task):
         nb_test_samples,
         batch_size,
         size,
+        nb_shapes,
+        nb_colors,
         logger=None,
         device=torch.device("cpu"),
     ):
@@ -1480,7 +1499,9 @@ class Grid(Task):
 
         self.device = device
         self.batch_size = batch_size
-        self.grid_factory = grid.GridFactory(size=size)
+        self.grid_factory = grid.GridFactory(
+            size=size, nb_shapes=nb_shapes, nb_colors=nb_colors
+        )
 
         if logger is not None:
             logger(
@@ -1515,11 +1536,13 @@ class Grid(Task):
         self.train_input = self.str2tensor(self.train_descr)
         self.test_input = self.str2tensor(self.test_descr)
 
-    def batches(self, split="train"):
+    def batches(self, split="train", desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
+        if desc is None:
+            desc = f"epoch-{split}"
         for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
         ):
             yield self.trim(batch)
 
@@ -1616,13 +1639,15 @@ class QMLP(Task):
            for e in self.test_ref_test_errors:
                f.write(f"{e}\n")
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
-    def batches(self, split="train"):
+    def batches(self, split="train", desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
+        if desc is None:
+            desc = f"epoch-{split}"
         for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
         ):
             yield batch
 
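Note (not part of the diff above): a minimal standalone sketch of the recurring nb_codes change, which wraps the expression and calls .item(). Tensor.max() returns a 0-dim tensor, and adding 1 keeps it a tensor, so without .item() nb_codes would stay a torch scalar rather than a plain Python int. The tensors below are hypothetical stand-ins for self.train_input / self.test_input.

import torch

# Hypothetical stand-ins for self.train_input / self.test_input.
train_input = torch.randint(0, 10, (4, 8))
test_input = torch.randint(0, 12, (4, 8))

# Tensor.max() returns a 0-dim tensor, and "+ 1" keeps it a tensor.
nb_codes_tensor = max(train_input.max(), test_input.max()) + 1

# .item() converts the 0-dim tensor into a plain Python int.
nb_codes = (max(train_input.max(), test_input.max()) + 1).item()

assert torch.is_tensor(nb_codes_tensor)
assert isinstance(nb_codes, int)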