X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=727b196b3a2f008854fb314389254589ab29d715;hb=a3c32b845b6903fd290f2b09d5c53203ff112b79;hp=afad8afbb83ab20d7b8f9da48be801ea14467298;hpb=a09ee76c8283b7daf4c914df47f86d1964fc25d4;p=mygptrnn.git diff --git a/tasks.py b/tasks.py index afad8af..727b196 100755 --- a/tasks.py +++ b/tasks.py @@ -1473,6 +1473,8 @@ class Grid(Task): nb_test_samples, batch_size, size, + nb_shapes, + nb_colors, logger=None, device=torch.device("cpu"), ): @@ -1480,7 +1482,9 @@ class Grid(Task): self.device = device self.batch_size = batch_size - self.grid_factory = grid.GridFactory(size=size) + self.grid_factory = grid.GridFactory( + size=size, nb_shapes=nb_shapes, nb_colors=nb_colors + ) if logger is not None: logger( @@ -1515,11 +1519,13 @@ class Grid(Task): self.train_input = self.str2tensor(self.train_descr) self.test_input = self.str2tensor(self.test_descr) - def batches(self, split="train"): + def batches(self, split="train", desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input + if desc is None: + desc = f"epoch-{split}" for batch in tqdm.tqdm( - input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" + input.split(self.batch_size), dynamic_ncols=True, desc=desc ): yield self.trim(batch) @@ -1618,11 +1624,13 @@ class QMLP(Task): self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 - def batches(self, split="train"): + def batches(self, split="train", desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input + if desc is None: + desc = f"epoch-{split}" for batch in tqdm.tqdm( - input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" + input.split(self.batch_size), dynamic_ncols=True, desc=desc ): yield batch