X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=afad8afbb83ab20d7b8f9da48be801ea14467298;hb=a09ee76c8283b7daf4c914df47f86d1964fc25d4;hp=58638ed95ae343842ab810727b4aefd9fe0daabe;hpb=cb737bdbd2f112826f739e4581fbe6546aeef638;p=mygptrnn.git diff --git a/tasks.py b/tasks.py index 58638ed..afad8af 100755 --- a/tasks.py +++ b/tasks.py @@ -58,7 +58,7 @@ def masked_inplace_autoregression( class Task: - def batches(self, split="train"): + def batches(self, split="train", desc=None): pass def vocabulary_size(self): @@ -328,7 +328,7 @@ class PicoCLVR(Task): self.train_input = self.tensorize(self.train_descr) self.test_input = self.tensorize(self.test_descr) - def batches(self, split="train"): + def batches(self, split="train", desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input for batch in tqdm.tqdm(