X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=4777a11676447c4137683ec988bd011a0ad69d81;hb=e56873a0cb64555cbd47e44cdca0ce991765a5fc;hp=58638ed95ae343842ab810727b4aefd9fe0daabe;hpb=4395f9a90218819997c706de9505cda1c86ad507;p=mygptrnn.git diff --git a/tasks.py b/tasks.py index 58638ed..4777a11 100755 --- a/tasks.py +++ b/tasks.py @@ -58,7 +58,7 @@ def masked_inplace_autoregression( class Task: - def batches(self, split="train"): + def batches(self, split="train", desc=None): pass def vocabulary_size(self): @@ -328,7 +328,7 @@ class PicoCLVR(Task): self.train_input = self.tensorize(self.train_descr) self.test_input = self.tensorize(self.test_descr) - def batches(self, split="train"): + def batches(self, split="train", desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input for batch in tqdm.tqdm( @@ -1515,11 +1515,13 @@ class Grid(Task): self.train_input = self.str2tensor(self.train_descr) self.test_input = self.str2tensor(self.test_descr) - def batches(self, split="train"): + def batches(self, split="train", desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input + if desc is None: + desc = f"epoch-{split}" for batch in tqdm.tqdm( - input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" + input.split(self.batch_size), dynamic_ncols=True, desc=desc ): yield self.trim(batch) @@ -1618,11 +1620,13 @@ class QMLP(Task): self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 - def batches(self, split="train"): + def batches(self, split="train", desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input + if desc is None: + desc = f"epoch-{split}" for batch in tqdm.tqdm( - input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" + input.split(self.batch_size), dynamic_ncols=True, desc=desc ): yield batch