X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=57c68018eca2ea2abba35478af62a15b91a5ab10;hb=refs%2Fheads%2Fmaster;hp=218ff36e0f7e37f67de3e4e227457b61d6414a60;hpb=408f2335af43590ee2d99c3286cbe3762c76887a;p=mygptrnn.git diff --git a/tasks.py b/tasks.py index 218ff36..57c6801 100755 --- a/tasks.py +++ b/tasks.py @@ -106,7 +106,7 @@ class SandBox(Task): device ), self.test_ar_mask.to(device) - self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item() # A bit of paranoia never hurts assert self.nb_codes <= max_nb_codes @@ -579,7 +579,7 @@ class Maze(Task): ) self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device)) - self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item() def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} @@ -756,7 +756,7 @@ class Snake(Task): self.device, ) - self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item() def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} @@ -871,7 +871,7 @@ class Stack(Task): counts = F.one_hot(counts).sum(0) logger(f"test_pop_stack_counts {counts}") - self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item() def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} @@ -1078,7 +1078,7 @@ class RPL(Task): s = " ".join(seq) logger(f"example_seq {s}") - self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item() def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} @@ -1308,7 +1308,7 @@ class Expr(Task): self.train_input = self.tensorize(train_sequences) self.test_input = self.tensorize(test_sequences) - self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item() def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} @@ -1639,7 +1639,7 @@ class QMLP(Task): for e in self.test_ref_test_errors: f.write(f"{e}\n") - self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item() def batches(self, split="train", desc=None): assert split in {"train", "test"}