Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 19 Feb 2024 13:56:02 +0000 (14:56 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 19 Feb 2024 13:56:02 +0000 (14:56 +0100)
main.py
tasks.py

diff --git a/main.py b/main.py
index 55f2c2f..9198edc 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -89,7 +89,9 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 ##############################
 # filetask
 
-parser.add_argument("--filetask_file", type=str, default=None)
+parser.add_argument("--filetask_train_file", type=str, default=None)
+
+parser.add_argument("--filetask_test_file", type=str, default=None)
 
 ##############################
 # rpl options
@@ -403,10 +405,11 @@ picoclvr_pruner_eval = (
 
 if args.task == "file":
     assert (
-        args.filetask_file is not None
-    ), "You have to specify the task file with --filetask_file <filename>"
+        args.filetask_train_file is not None and args.filetask_test_file is not None
+    ), "You have to specify the task train and test files"
     task = tasks.TaskFromFile(
-        args.filetask_file,
+        args.filetask_train_file,
+        args.filetask_test_file,
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
         batch_size=args.batch_size,
index 1ea3b5d..e5d3a7e 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -117,7 +117,8 @@ class TaskFromFile(Task):
 
     def __init__(
         self,
-        filename,
+        train_filename,
+        test_filename,
         nb_train_samples,
         nb_test_samples,
         batch_size,
@@ -126,26 +127,37 @@ class TaskFromFile(Task):
         self.batch_size = batch_size
         self.device = device
 
-        pairs = []
-        with open(filename, "r") as f:
-            for _ in range(nb_train_samples + nb_test_samples):
-                sequence = f.readline().strip()
-                pred_mask = f.readline().strip()
-                assert len(sequence) == len(pred_mask)
-                assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
-                pairs.append((sequence, pred_mask))
-
-        symbols = ["#"] + list(set("".join([x[0] for x in pairs])) - set(["#"]))
+        def read_file(filename, nb=-1):
+            pairs = []
+            with open(filename, "r") as f:
+                while True:
+                    sequence = f.readline().strip()
+                    if not sequence:
+                        break
+                    pred_mask = f.readline().strip()
+                    assert len(sequence) == len(pred_mask)
+                    assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
+                    pairs.append((sequence, pred_mask))
+                    if len(pairs) == nb:
+                        break
+
+            if nb > 0:
+                pairs = pairs[:nb]
+                assert len(pairs) == nb
+
+            return pairs
+
+        train_pairs = read_file(train_filename, nb_train_samples)
+        test_pairs = read_file(test_filename, nb_test_samples)
+
+        symbols = ["#"] + list(
+            set("".join([x[0] for x in train_pairs + test_pairs])) - set(["#"])
+        )
         self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
 
-        self.train_input, self.train_pred_masks = self.tensorize(
-            pairs[:nb_train_samples]
-        )
-        self.test_input, self.test_pred_masks = self.tensorize(pairs[nb_train_samples:])
-
-        assert self.train_input.size(0) == nb_train_samples
-        assert self.test_input.size(0) == nb_test_samples
+        self.train_input, self.train_pred_masks = self.tensorize(train_pairs)
+        self.test_input, self.test_pred_masks = self.tensorize(test_pairs)
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -176,7 +188,7 @@ class TaskFromFile(Task):
 
         logger(f"----------------------------------------------------------")
 
-        for e in self.tensor2str(result[:10]):
+        for e in self.tensor2str(result[:50]):
             logger(f"test_before {e}")
 
         masked_inplace_autoregression(
@@ -190,7 +202,7 @@ class TaskFromFile(Task):
 
         logger(f"----------------------------------------------------------")
 
-        for e, c in zip(self.tensor2str(result[:10]), self.tensor2str(correct[:10])):
+        for e, c in zip(self.tensor2str(result[:50]), self.tensor2str(correct[:50])):
             logger(f"test_after  {e}")
             logger(f"correct     {c}")