X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=4777a11676447c4137683ec988bd011a0ad69d81;hb=e56873a0cb64555cbd47e44cdca0ce991765a5fc;hp=afad8afbb83ab20d7b8f9da48be801ea14467298;hpb=3dd98b99909b2bca323673263874e2abb39ac10c;p=mygptrnn.git diff --git a/tasks.py b/tasks.py index afad8af..4777a11 100755 --- a/tasks.py +++ b/tasks.py @@ -1515,11 +1515,13 @@ class Grid(Task): self.train_input = self.str2tensor(self.train_descr) self.test_input = self.str2tensor(self.test_descr) - def batches(self, split="train"): + def batches(self, split="train", desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input + if desc is None: + desc = f"epoch-{split}" for batch in tqdm.tqdm( - input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" + input.split(self.batch_size), dynamic_ncols=True, desc=desc ): yield self.trim(batch) @@ -1618,11 +1620,13 @@ class QMLP(Task): self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 - def batches(self, split="train"): + def batches(self, split="train", desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input + if desc is None: + desc = f"epoch-{split}" for batch in tqdm.tqdm( - input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" + input.split(self.batch_size), dynamic_ncols=True, desc=desc ): yield batch