device
), self.test_ar_mask.to(device)
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
# A bit of paranoia never hurts
assert self.nb_codes <= max_nb_codes
)
self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
self.device,
)
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
counts = F.one_hot(counts).sum(0)
logger(f"test_pop_stack_counts {counts}")
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
s = " ".join(seq)
logger(f"example_seq {s}")
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
self.train_input = self.tensorize(train_sequences)
self.test_input = self.tensorize(test_sequences)
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
for e in self.test_ref_test_errors:
f.write(f"{e}\n")
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", desc=None):
assert split in {"train", "test"}