t = torch.arange(input.size(1) + memex_len, device=input.device)[
None, :
].expand(input.size(0), -1)
+ n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+ -1, t.size(1)
+ )
# Call me the tensor-spaghetti master
trigger = torch.rand(t.size(), device=t.device)
- trigger[:, -memex_len:] = 1.0
- trigger = (trigger.sort(dim=1).indices == 0).long()
+ trigger[:, -memex_len:] = 2.0
+ trigger[:, 0] = 2.0
+ trigger = (trigger == trigger.min(dim=1, keepdim=True).values).long()
memex_mask = trigger.clone()
- memex_mask[:, memex_len:] -= memex_mask[:, :-memex_len]
+ memex_mask[:, memex_len:] -= trigger[:, :-memex_len]
memex_mask = memex_mask.cumsum(dim=1)
+
u = 1 - memex_mask
u[:, 0] = 0
u = u.cumsum(dim=1)
- # assert u.min() == 0
- # assert u.max() == input.size(1) - 1
+ assert u.min() == 0
+ assert u.max() == input.size(1) - 1
+
v = (
(trigger.cumsum(dim=1) - trigger).cumsum(dim=1)
- + torch.randint(input.size(1), (input.size(0), 1), device=t.device)
+ + torch.randint(
+ input.size(1) - memex_len, (input.size(0), 1), device=t.device
+ )
) * memex_mask
+ assert v.min() >= 0
+ assert v.max() < input.size(1)
u = u * (1 - memex_mask) + v * memex_mask
- n = torch.arange(input.size(0), device=input.device)[:, None].expand(
- -1, t.size(1)
- )
+
new_input = input[n, u]
+ assert input.max() < vocabulary_size
+ assert new_input.max() < vocabulary_size
limits = trigger.clone()
limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)]
- new_input = new_input * (1 - limits) + memex_marker * limits
+ assert limits.min() == 0
+ assert limits.max() == 1
+ new_input = new_input * (1 - limits) + marker_token * limits
+ assert marker_token < vocabulary_size
+ assert new_input.max() < vocabulary_size
yield new_input, memex_mask
device
), self.test_ar_mask.to(device)
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
# A bit of paranoia never hurts
assert self.nb_codes <= max_nb_codes
)
self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
self.device,
)
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
counts = F.one_hot(counts).sum(0)
logger(f"test_pop_stack_counts {counts}")
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
s = " ".join(seq)
logger(f"example_seq {s}")
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
self.train_input = self.tensorize(train_sequences)
self.test_input = self.tensorize(test_sequences)
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
for e in self.test_ref_test_errors:
f.write(f"{e}\n")
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", desc=None):
assert split in {"train", "test"}