def add_memex_v3(batches, memex_proba, marker_token):
    """Randomly splice a "memex" re-read segment into training batches.

    For each incoming batch, a candidate sequence memex_len tokens longer
    than the input is built: a contiguous window of the input is replayed
    (the "memex"), with its boundaries overwritten by marker_token. Each
    sample then independently keeps either this spliced sequence (with
    probability memex_proba) or the original input right-padded with
    marker_token to the same length.

    Args:
        batches: iterable of 2-D integer tensors of shape (batch, seq_len).
            NOTE(review): assumes seq_len >= 16 so memex_len >= 2; with
            memex_len <= 1 the `limits` slicing below breaks — confirm
            against callers.
        memex_proba: per-sample probability of emitting the spliced version.
        marker_token: token id used to delimit / pad the memex segment.

    Yields:
        2-D tensors of shape (batch, seq_len + seq_len // 8).
    """
    for input in batches:
        memex_len = input.size(1) // 8

        # t: positions over the extended sequence; n: per-row sample index,
        # both broadcast to the full (batch, seq_len + memex_len) grid.
        t = torch.arange(input.size(1) + memex_len, device=input.device)[
            None, :
        ].expand(input.size(0), -1)
        n = torch.arange(input.size(0), device=input.device)[:, None].expand(
            -1, t.size(1)
        )

        t = (t - 1).clamp(min=0)

        # Call me the tensor-spaghetti master

        # Pick one random trigger position per row, forced away from both
        # ends so the memex window fits entirely inside the sequence.
        trigger = torch.rand(t.size(), device=t.device)
        trigger[:, -memex_len:] = 2.0
        trigger[:, : memex_len + 1] = 2.0
        trigger = (trigger == trigger.min(dim=1, keepdim=True).values).long()

        # memex_mask: 1 on the memex_len positions starting at the trigger.
        memex_mask = trigger.clone()
        memex_mask[:, memex_len:] -= trigger[:, :-memex_len]
        memex_mask = memex_mask.cumsum(dim=1)

        # u: gather indices for the regular (non-memex) positions.
        u = 1 - memex_mask
        u[:, 0] = 0
        u = u.cumsum(dim=1)

        # v: gather indices inside the memex window — a random start offset
        # plus the position within the window.
        v = (
            (trigger.cumsum(dim=1) - trigger).cumsum(dim=1)
            + torch.randint(
                input.size(1) - memex_len, (input.size(0), 1), device=t.device
            )
        ) * memex_mask
        u = u * (1 - memex_mask) + v * memex_mask

        new_input = input[n, u]

        # Overwrite the window boundaries (and position 0) with the marker.
        limits = trigger.clone()
        limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)]
        new_input = new_input * (1 - limits) + marker_token * limits
        new_input[:, 0] = marker_token

        # The original sequence, right-padded with marker_token to match.
        # BUG FIX: the fill value was the undefined name `memex_marker`
        # (NameError at runtime); the function parameter is `marker_token`.
        orig = torch.cat(
            [
                input,
                torch.full(
                    (input.size(0), memex_len), marker_token, device=t.device
                ),
            ],
            dim=1,
        )

        # Per-sample coin flip: spliced sequence vs padded original.
        a = (torch.rand(input.size(0), 1, device=t.device) <= memex_proba).long()
        new_input = (1 - a) * orig + a * new_input

        yield new_input  # memex_mask
######################################################################

log_string(f"memex_proba {memex_proba}")

# Wrap the training stream with the memex augmentation only when it is
# actually enabled; otherwise pass the raw batches through unchanged so
# the augmentation path adds no overhead (and no warning) by default.
if args.memex_proba > 0:
    warnings.warn("memex v3", RuntimeWarning)
    train_batches = add_memex_v3(
        batches=task.batches(split="train"),
        memex_proba=memex_proba,
        marker_token=memex_marker,
    )
else:
    train_batches = task.batches(split="train")
def add_none(it):
for x in it:
optimizer.step()

# Global gradient norm over all parameters, for the loss log.
grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")

# Only log the memex lambda statistics when this batch actually carried
# a memex mask (the augmentation may be disabled or skipped per batch).
if memex_mask is not None:
    lambda_file.write(
        f"{n_epoch} {n_batch} {l_memex} {norm_regular} {norm_memex}\n"
    )

optimizer.zero_grad()
nb_acc_samples = 0

n_batch += 1
return sequences, ar_mask
def seq2str(self, seq):
    """Decode a 1-D tensor of token ids into a string.

    Ids outside the vocabulary are rendered as "?" rather than raising.
    """

    def decode(x):
        # BUG FIX: `x < len(...)` let NEGATIVE ids fall through to
        # Python's negative indexing, returning a token from the END of
        # the vocabulary; treat them as unknown ("?") as well.
        if 0 <= x < len(self.token_string):
            return self.token_string[x]
        return "?"

    return "".join(decode(x.item()) for x in seq)
class ProblemTwoTargets(Problem):
######################################################################

# Disabled pscan timing micro-benchmark (kept for reference).

# N, T, D = 16, 4096, 32

# for r in range(timing.size(0)):
#     A = 0.9 + 0.1 * torch.rand(N, T, dtype=torch.float64).requires_grad_()
#     X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
#     Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()
#
#     start_time = time.perf_counter()
#     for _ in range(1000):
#         Y = pscan(A, X, Y_init)
#     duration = time.perf_counter() - start_time
######################################################################
device
), self.test_ar_mask.to(device)
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
# A bit of paranoia never hurts
assert self.nb_codes <= max_nb_codes
)
self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
self.device,
)
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
counts = F.one_hot(counts).sum(0)
logger(f"test_pop_stack_counts {counts}")
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
s = " ".join(seq)
logger(f"example_seq {s}")
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
self.train_input = self.tensorize(train_sequences)
self.test_input = self.tensorize(test_sequences)
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
for e in self.test_ref_test_errors:
f.write(f"{e}\n")
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = (max(self.train_input.max(), self.test_input.max()) + 1).item()
def batches(self, split="train", desc=None):
assert split in {"train", "test"}