sequence = f.readline().strip()
pred_mask = f.readline().strip()
assert len(sequence) == len(pred_mask)
- assert set(pred_mask) == {"0", "1", "2"}, f"{set(pred_mask)}"
+ assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
pairs.append((sequence, pred_mask))
symbols = ["#"] + list(set("".join([x[0] for x in pairs])) - set(["#"]))
)
self.test_input, self.test_pred_masks = self.tensorize(pairs[nb_train_samples:])
+ assert self.train_input.size(0) == nb_train_samples
+ assert self.test_input.size(0) == nb_test_samples
+
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input