Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 20 Feb 2024 08:50:44 +0000 (09:50 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 20 Feb 2024 08:50:44 +0000 (09:50 +0100)
main.py
problems.py
tasks.py

diff --git a/main.py b/main.py
index 00b8301..a587e96 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -604,32 +604,46 @@ def add_memex_v3(batches, memex_proba, marker_token):
             t = torch.arange(input.size(1) + memex_len, device=input.device)[
                 None, :
             ].expand(input.size(0), -1)
+            n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+                -1, t.size(1)
+            )
 
             # Call me the tensor-spaghetti master
 
             trigger = torch.rand(t.size(), device=t.device)
-            trigger[:, -memex_len:] = 1.0
-            trigger = (trigger.sort(dim=1).indices == 0).long()
+            trigger[:, -memex_len:] = 2.0
+            trigger[:, 0] = 2.0
+            trigger = (trigger == trigger.min(dim=1, keepdim=True).values).long()
             memex_mask = trigger.clone()
-            memex_mask[:, memex_len:] -= memex_mask[:, :-memex_len]
+            memex_mask[:, memex_len:] -= trigger[:, :-memex_len]
             memex_mask = memex_mask.cumsum(dim=1)
+
             u = 1 - memex_mask
             u[:, 0] = 0
             u = u.cumsum(dim=1)
-            # assert u.min() == 0
-            # assert u.max() == input.size(1) - 1
+            assert u.min() == 0
+            assert u.max() == input.size(1) - 1
+
             v = (
                 (trigger.cumsum(dim=1) - trigger).cumsum(dim=1)
-                + torch.randint(input.size(1), (input.size(0), 1), device=t.device)
+                + torch.randint(
+                    input.size(1) - memex_len, (input.size(0), 1), device=t.device
+                )
             ) * memex_mask
+            assert v.min() >= 0
+            assert v.max() < input.size(1)
             u = u * (1 - memex_mask) + v * memex_mask
-            n = torch.arange(input.size(0), device=input.device)[:, None].expand(
-                -1, t.size(1)
-            )
+
             new_input = input[n, u]
+            assert input.max() < vocabulary_size
+            assert new_input.max() < vocabulary_size
             limits = trigger.clone()
             limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)]
-            new_input = new_input * (1 - limits) + memex_marker * limits
+            assert limits.min() == 0
+            assert limits.max() == 1
+            new_input = new_input * (1 - limits) + marker_token * limits
+            assert marker_token < vocabulary_size
+            assert new_input.max() < vocabulary_size
 
             yield new_input, memex_mask
 
index 9e368c2..3cdd374 100755 (executable)
@@ -149,7 +149,13 @@ class ProblemMemory(Problem):
         return sequences, ar_mask
 
     def seq2str(self, seq):
-        return "".join(self.token_string[x.item()] for x in seq)
+        def decode(x):
+            if x < len(self.token_string):
+                return self.token_string[x]
+            else:
+                return "?"
+
+        return "".join(decode(x.item()) for x in seq)
 
 
 class ProblemTwoTargets(Problem):
index 218ff36..57c6801 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -106,7 +106,7 @@ class SandBox(Task):
             device
         ), self.test_ar_mask.to(device)
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
         # A bit of paranoia never hurts
         assert self.nb_codes <= max_nb_codes
@@ -579,7 +579,7 @@ class Maze(Task):
         )
         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -756,7 +756,7 @@ class Snake(Task):
             self.device,
         )
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -871,7 +871,7 @@ class Stack(Task):
         counts = F.one_hot(counts).sum(0)
         logger(f"test_pop_stack_counts {counts}")
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -1078,7 +1078,7 @@ class RPL(Task):
                 s = " ".join(seq)
                 logger(f"example_seq {s}")
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -1308,7 +1308,7 @@ class Expr(Task):
         self.train_input = self.tensorize(train_sequences)
         self.test_input = self.tensorize(test_sequences)
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -1639,7 +1639,7 @@ class QMLP(Task):
             for e in self.test_ref_test_errors:
                 f.write(f"{e}\n")
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
 
     def batches(self, split="train", desc=None):
         assert split in {"train", "test"}