From 9d4193312e06ed284b1368b7f4407f2b4f981c7a Mon Sep 17 00:00:00 2001
From: François Fleuret
Date: Tue, 20 Feb 2024 09:50:44 +0100
Subject: [PATCH] Update.

---
 main.py     | 34 ++++++++++++++++++++++++----------
 problems.py |  8 +++++++-
 tasks.py    | 14 +++++++-------
 3 files changed, 38 insertions(+), 18 deletions(-)

diff --git a/main.py b/main.py
index 00b8301..a587e96 100755
--- a/main.py
+++ b/main.py
@@ -604,32 +604,46 @@ def add_memex_v3(batches, memex_proba, marker_token):
         t = torch.arange(input.size(1) + memex_len, device=input.device)[
             None, :
         ].expand(input.size(0), -1)
+        n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+            -1, t.size(1)
+        )

         # Call me the tensor-spaghetti master

         trigger = torch.rand(t.size(), device=t.device)
-        trigger[:, -memex_len:] = 1.0
-        trigger = (trigger.sort(dim=1).indices == 0).long()
+        trigger[:, -memex_len:] = 2.0
+        trigger[:, 0] = 2.0
+        trigger = (trigger == trigger.min(dim=1, keepdim=True).values).long()
         memex_mask = trigger.clone()
-        memex_mask[:, memex_len:] -= memex_mask[:, :-memex_len]
+        memex_mask[:, memex_len:] -= trigger[:, :-memex_len]
         memex_mask = memex_mask.cumsum(dim=1)
+
         u = 1 - memex_mask
         u[:, 0] = 0
         u = u.cumsum(dim=1)
-        # assert u.min() == 0
-        # assert u.max() == input.size(1) - 1
+        assert u.min() == 0
+        assert u.max() == input.size(1) - 1
+
         v = (
             (trigger.cumsum(dim=1) - trigger).cumsum(dim=1)
-            + torch.randint(input.size(1), (input.size(0), 1), device=t.device)
+            + torch.randint(
+                input.size(1) - memex_len, (input.size(0), 1), device=t.device
+            )
         ) * memex_mask
+        assert v.min() >= 0
+        assert v.max() < input.size(1)
         u = u * (1 - memex_mask) + v * memex_mask
-        n = torch.arange(input.size(0), device=input.device)[:, None].expand(
-            -1, t.size(1)
-        )
+
         new_input = input[n, u]
+        assert input.max() < vocabulary_size
+        assert new_input.max() < vocabulary_size

         limits = trigger.clone()
         limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)]
-        new_input = new_input * (1 - limits) + memex_marker * limits
+        assert limits.min() == 0
+        assert limits.max() == 1
+        new_input = new_input * (1 - limits) + marker_token * limits
+        assert marker_token < vocabulary_size
+        assert new_input.max() < vocabulary_size

         yield new_input, memex_mask

diff --git a/problems.py b/problems.py
index 9e368c2..3cdd374 100755
--- a/problems.py
+++ b/problems.py
@@ -149,7 +149,13 @@ class ProblemMemory(Problem):
         return sequences, ar_mask

     def seq2str(self, seq):
-        return "".join(self.token_string[x.item()] for x in seq)
+        def decode(x):
+            if x < len(self.token_string):
+                return self.token_string[x]
+            else:
+                return "?"
+
+        return "".join(decode(x.item()) for x in seq)


 class ProblemTwoTargets(Problem):

diff --git a/tasks.py b/tasks.py
index 218ff36..57c6801 100755
--- a/tasks.py
+++ b/tasks.py
@@ -106,7 +106,7 @@ class SandBox(Task):
             device
         ), self.test_ar_mask.to(device)

-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()

         # A bit of paranoia never hurts
         assert self.nb_codes <= max_nb_codes
@@ -579,7 +579,7 @@ class Maze(Task):
         )
         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))

-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()

     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -756,7 +756,7 @@ class Snake(Task):
             self.device,
         )

-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()

     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -871,7 +871,7 @@ class Stack(Task):
             counts = F.one_hot(counts).sum(0)
             logger(f"test_pop_stack_counts {counts}")

-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()

     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -1078,7 +1078,7 @@ class RPL(Task):
                 s = " ".join(seq)
                 logger(f"example_seq {s}")

-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()

     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -1308,7 +1308,7 @@ class Expr(Task):
         self.train_input = self.tensorize(train_sequences)
         self.test_input = self.tensorize(test_sequences)

-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()

     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
@@ -1639,7 +1639,7 @@ class QMLP(Task):
                 for e in self.test_ref_test_errors:
                     f.write(f"{e}\n")

-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()

     def batches(self, split="train", desc=None):
         assert split in {"train", "test"}
--
2.39.5
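Note: the heart of the main.py change is how a random admissible segment start is drawn per sequence and expanded into a contiguous mask. Below is a minimal standalone sketch of that construction; the dimensions B, T and the value of memex_len are toy values chosen for illustration, not the repository's settings.

# Standalone sketch of the trigger/mask construction from add_memex_v3,
# with toy dimensions (not the repository's settings).
import torch

B, T, memex_len = 3, 12, 4

# Random scores per position; positions that must not start a segment
# (position 0, and the last memex_len positions, where the segment would
# not fit) are pushed to 2.0 so they can never be the row minimum.
trigger = torch.rand(B, T)
trigger[:, -memex_len:] = 2.0
trigger[:, 0] = 2.0

# One-hot indicator of the (almost surely unique) minimum per row,
# i.e. a uniformly drawn admissible start position.
trigger = (trigger == trigger.min(dim=1, keepdim=True).values).long()

# +1 at the start, -1 memex_len positions later, then a cumulative sum:
# the result is 1 exactly on the memex_len selected positions.
memex_mask = trigger.clone()
memex_mask[:, memex_len:] -= trigger[:, :-memex_len]
memex_mask = memex_mask.cumsum(dim=1)

assert (memex_mask.sum(dim=1) == memex_len).all()
print(memex_mask)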
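The change repeated across tasks.py makes nb_codes a plain Python int instead of a zero-dimensional tensor, which is what max(a.max(), b.max()) + 1 otherwise returns. A quick illustration with throwaway tensors (not taken from the patch):

import torch

train_input = torch.randint(10, (5, 20))
test_input = torch.randint(10, (5, 20))

nb_codes = max(train_input.max(), test_input.max()) + 1
print(type(nb_codes))         # <class 'torch.Tensor'> (zero-dimensional)
print(type(nb_codes.item()))  # <class 'int'>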