From 4f489998d6e73680c3a031e8932a7678c16268e3 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 18 Feb 2024 23:30:01 +0100 Subject: [PATCH] Update. --- tasks.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tasks.py b/tasks.py index 78910a0..1ea3b5d 100755 --- a/tasks.py +++ b/tasks.py @@ -144,6 +144,9 @@ class TaskFromFile(Task): ) self.test_input, self.test_pred_masks = self.tensorize(pairs[nb_train_samples:]) + assert self.train_input.size(0) == nb_train_samples + assert self.test_input.size(0) == nb_test_samples + def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input -- 2.39.5