Update.
[culture.git] / tasks.py
index 181ac44..1ea3b5d 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -132,7 +132,7 @@ class TaskFromFile(Task):
                 sequence = f.readline().strip()
                 pred_mask = f.readline().strip()
                 assert len(sequence) == len(pred_mask)
-                assert set(pred_mask) == {"0", "1", "2"}, f"{set(pred_mask)}"
+                assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
                 pairs.append((sequence, pred_mask))
 
         symbols = ["#"] + list(set("".join([x[0] for x in pairs])) - set(["#"]))
@@ -144,6 +144,9 @@ class TaskFromFile(Task):
         )
         self.test_input, self.test_pred_masks = self.tensorize(pairs[nb_train_samples:])
 
+        assert self.train_input.size(0) == nb_train_samples
+        assert self.test_input.size(0) == nb_test_samples
+
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input