From d0647484846de3985ed92d36ef554c47e2898902 Mon Sep 17 00:00:00 2001
From: François Fleuret
Date: Sat, 31 Aug 2024 09:50:05 +0200
Subject: [PATCH] Update.

---
 main.py  | 111 +++++++++++++++++++++++++++++--------------------------
 pscan.py |  18 ++++-----
 2 files changed, 67 insertions(+), 62 deletions(-)

diff --git a/main.py b/main.py
index a587e96..88f56b3 100755
--- a/main.py
+++ b/main.py
@@ -598,57 +598,58 @@ def add_memex_v2(batches, memex_proba, marker_token):

 def add_memex_v3(batches, memex_proba, marker_token):
     for input in batches:
-        if torch.rand(1).item() < memex_proba:
-            memex_len = input.size(1) // 4
+        memex_len = input.size(1) // 8

-            t = torch.arange(input.size(1) + memex_len, device=input.device)[
-                None, :
-            ].expand(input.size(0), -1)
-            n = torch.arange(input.size(0), device=input.device)[:, None].expand(
-                -1, t.size(1)
-            )
+        t = torch.arange(input.size(1) + memex_len, device=input.device)[
+            None, :
+        ].expand(input.size(0), -1)
+        n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+            -1, t.size(1)
+        )

-            # Call me the tensor-spaghetti master
+        t = (t - 1).clamp(min=0)

-            trigger = torch.rand(t.size(), device=t.device)
-            trigger[:, -memex_len:] = 2.0
-            trigger[:, 0] = 2.0
-            trigger = (trigger == trigger.min(dim=1, keepdim=True).values).long()
-            memex_mask = trigger.clone()
-            memex_mask[:, memex_len:] -= trigger[:, :-memex_len]
-            memex_mask = memex_mask.cumsum(dim=1)
+        # Call me the tensor-spaghetti master

-            u = 1 - memex_mask
-            u[:, 0] = 0
-            u = u.cumsum(dim=1)
-            assert u.min() == 0
-            assert u.max() == input.size(1) - 1
+        trigger = torch.rand(t.size(), device=t.device)
+        trigger[:, -memex_len:] = 2.0
+        trigger[:, : memex_len + 1] = 2.0
+        trigger = (trigger == trigger.min(dim=1, keepdim=True).values).long()
+        memex_mask = trigger.clone()
+        memex_mask[:, memex_len:] -= trigger[:, :-memex_len]
+        memex_mask = memex_mask.cumsum(dim=1)

-            v = (
-                (trigger.cumsum(dim=1) - trigger).cumsum(dim=1)
-                + torch.randint(
-                    input.size(1) - memex_len, (input.size(0), 1), device=t.device
-                )
-            ) * memex_mask
-            assert v.min() >= 0
-            assert v.max() < input.size(1)
-            u = u * (1 - memex_mask) + v * memex_mask
-
-            new_input = input[n, u]
-            assert input.max() < vocabulary_size
-            assert new_input.max() < vocabulary_size
-            limits = trigger.clone()
-            limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)]
-            assert limits.min() == 0
-            assert limits.max() == 1
-            new_input = new_input * (1 - limits) + marker_token * limits
-            assert marker_token < vocabulary_size
-            assert new_input.max() < vocabulary_size
+        u = 1 - memex_mask
+        u[:, 0] = 0
+        u = u.cumsum(dim=1)

-            yield new_input, memex_mask
+        v = (
+            (trigger.cumsum(dim=1) - trigger).cumsum(dim=1)
+            + torch.randint(
+                input.size(1) - memex_len, (input.size(0), 1), device=t.device
+            )
+        ) * memex_mask
+        u = u * (1 - memex_mask) + v * memex_mask
+
+        new_input = input[n, u]
+        limits = trigger.clone()
+        limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)]
+        new_input = new_input * (1 - limits) + marker_token * limits
+        new_input[:, 0] = marker_token
+
+        orig = torch.cat(
+            [
+                input,
+                torch.full((input.size(0), memex_len), memex_marker, device=t.device),
+            ],
+            dim=1,
+        )

-        else:
-            yield input
+        a = (torch.rand(input.size(0), 1, device=t.device) <= memex_proba).long()
+
+        new_input = (1 - a) * orig + a * new_input
+
+        yield new_input  # memex_mask


 ######################################################################
@@ -1068,12 +1069,15 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):

     log_string(f"memex_proba {memex_proba}")

-    warnings.warn("memex v3", RuntimeWarning)
-    train_batches = add_memex_v3(
-        batches=task.batches(split="train"),
-        memex_proba=memex_proba,
-        marker_token=memex_marker,
-    )
+    if args.memex_proba > 0:
+        warnings.warn("memex v3", RuntimeWarning)
+        train_batches = add_memex_v3(
+            batches=task.batches(split="train"),
+            memex_proba=memex_proba,
+            marker_token=memex_marker,
+        )
+    else:
+        train_batches = task.batches(split="train")

     def add_none(it):
         for x in it:
@@ -1140,9 +1144,10 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
             optimizer.step()
             grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
             loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
-            lambda_file.write(
-                f"{n_epoch} {n_batch} {l_memex} {norm_regular} {norm_memex}\n"
-            )
+            if memex_mask is not None:
+                lambda_file.write(
+                    f"{n_epoch} {n_batch} {l_memex} {norm_regular} {norm_memex}\n"
+                )
             optimizer.zero_grad()
             nb_acc_samples = 0
         n_batch += 1

diff --git a/pscan.py b/pscan.py
index 0bb0d14..b533164 100755
--- a/pscan.py
+++ b/pscan.py
@@ -124,17 +124,17 @@ if __name__ == "__main__":

     ######################################################################

-    N, T, D = 16, 4096, 32
+    # N, T, D = 16, 4096, 32

-    for r in range(timing.size(0)):
-        A = 0.9 + 0.1 * torch.rand(N, T, dtype=torch.float64).requires_grad_()
-        X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
-        Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()
+    # for r in range(timing.size(0)):
+    #     A = 0.9 + 0.1 * torch.rand(N, T, dtype=torch.float64).requires_grad_()
+    #     X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
+    #     Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()

-        start_time = time.perf_counter()
-        for _ in range(1000):
-            Y = pscan(A, X, Y_init)
-        duration = time.perf_counter() - start_time
+    #     start_time = time.perf_counter()
+    #     for _ in range(1000):
+    #         Y = pscan(A, X, Y_init)
+    #     duration = time.perf_counter() - start_time

     ######################################################################

-- 
2.39.5
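
Annotation (not part of the patch): the rewritten add_memex_v3 marks its random
splice window with a shifted-trigger cumsum. A minimal 1-D sketch of that idiom
follows; the sizes T, w and the trigger position s are hypothetical toy values
chosen for illustration, not taken from the patch.

    import torch

    # A one-hot "trigger" at position s becomes a mask that is 1 on [s, s + w).
    T, w, s = 12, 3, 5  # hypothetical toy sizes, not from the patch
    trigger = torch.zeros(T, dtype=torch.long)
    trigger[s] = 1

    mask = trigger.clone()
    mask[w:] -= trigger[:-w]   # +1 at position s, -1 at position s + w
    mask = mask.cumsum(dim=0)  # 1 exactly on positions s .. s + w - 1
    print(mask.tolist())       # [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0]

    # The gather index u enumerates the original positions outside the window;
    # the patch then overwrites the in-window entries with a second index v
    # pointing at a random earlier chunk before gathering input[n, u].
    u = 1 - mask
    u[0] = 0
    u = u.cumsum(dim=0)
    print(u.tolist())          # [0, 1, 2, 3, 4, 4, 4, 4, 5, 6, 7, 8]

In the patch the same construction runs batched along dim=1, with the trigger
drawn as the argmin of uniform noise whose forbidden border positions are
forced to 2.0 so the window always lands in the interior of the sequence.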