self.train_input = self.str2tensor(self.train_descr)
self.test_input = self.str2tensor(self.test_descr)
- def batches(self, split="train"):
+ def batches(self, split="train", desc=None):
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
+ if desc is None:
+ desc = f"epoch-{split}"
for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+ input.split(self.batch_size), dynamic_ncols=True, desc=desc
):
yield self.trim(batch)
self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
- def batches(self, split="train"):
+ def batches(self, split="train", desc=None):
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
+ if desc is None:
+ desc = f"epoch-{split}"
for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+ input.split(self.batch_size), dynamic_ncols=True, desc=desc
):
yield batch