X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=tasks.py;h=1ea3b5d588d37e4c89ca4d06f74d20356b7a2fcf;hb=4f489998d6e73680c3a031e8932a7678c16268e3;hp=181ac4475a2126df33206e2a4cea29a3f2d57833;hpb=732349f7c16e43ff84380d28e021d671f2c56492;p=culture.git diff --git a/tasks.py b/tasks.py index 181ac44..1ea3b5d 100755 --- a/tasks.py +++ b/tasks.py @@ -132,7 +132,7 @@ class TaskFromFile(Task): sequence = f.readline().strip() pred_mask = f.readline().strip() assert len(sequence) == len(pred_mask) - assert set(pred_mask) == {"0", "1", "2"}, f"{set(pred_mask)}" + assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}" pairs.append((sequence, pred_mask)) symbols = ["#"] + list(set("".join([x[0] for x in pairs])) - set(["#"])) @@ -144,6 +144,9 @@ class TaskFromFile(Task): ) self.test_input, self.test_pred_masks = self.tensorize(pairs[nb_train_samples:]) + assert self.train_input.size(0) == nb_train_samples + assert self.test_input.size(0) == nb_test_samples + def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input