Update.
[mygptrnn.git] / tasks.py
index 58638ed..afad8af 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -58,7 +58,7 @@ def masked_inplace_autoregression(
 
 
 class Task:
-    def batches(self, split="train"):
+    def batches(self, split="train", desc=None):
         pass
 
     def vocabulary_size(self):
@@ -328,7 +328,7 @@ class PicoCLVR(Task):
         self.train_input = self.tensorize(self.train_descr)
         self.test_input = self.tensorize(self.test_descr)
 
-    def batches(self, split="train"):
+    def batches(self, split="train", desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
         for batch in tqdm.tqdm(