From c0750e416e28fbdc9f6dc03cc6d7b11edd1ac333 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 6 Jan 2024 19:07:35 +0100 Subject: [PATCH] Update. --- mygpt.py | 2 +- pscan.py | 169 +++++++++++++++++++++++++++++++++++-------------------- 2 files changed, 108 insertions(+), 63 deletions(-) diff --git a/mygpt.py b/mygpt.py index c061eb4..4d48247 100755 --- a/mygpt.py +++ b/mygpt.py @@ -563,7 +563,7 @@ class Caterpillar(nn.Module): # by updating that at time t-L, the parallel scan operates # with a period of L. To do so we split the time indexing in # two axes, the second of size CL, and run the parallel scan - # using the other alone as the sequence index. + # using the other as the sequence index. A = A.unflatten(2, (-1, CL)) gated_V = gated_V.unflatten(2, (-1, CL)) diff --git a/pscan.py b/pscan.py index 0ec7b13..88cb3d5 100755 --- a/pscan.py +++ b/pscan.py @@ -11,8 +11,8 @@ import torch class PScan(torch.autograd.Function): - # Given A is NxTx1 and X is NxTxD, expands A and X in place in O(T), - # and O(log(T)) if not core-bounded, so that + # Given A is NxTxMx1 and X is NxTxMxD, expands A and X in + # place in O(T), and O(log(T)) if not core-bounded, so that # # Y[:, 0] = Y_init # Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t] @@ -23,33 +23,57 @@ class PScan(torch.autograd.Function): @staticmethod def expand_(A, X): - if A.size(1) == 1: - return - T = 2 * (A.size(1) // 2) - Aa = A[:, :T].view(A.size(0), T // 2, 2, -1, 1) - Xa = X[:, :T].view(X.size(0), T // 2, 2, -1, X.size(-1)) - Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0])) - Aa[:, :, 1].mul_(Aa[:, :, 0]) - PScan.expand_(Aa[:, :, 1], Xa[:, :, 1]) - Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1])) - Aa[:, 1:, 0].mul_(Aa[:, :-1, 1]) - if T < A.size(1): - X[:, -1].add_(A[:, -1].mul(X[:, -2])) - A[:, -1].mul_(A[:, -2]) + # Unrolling gains ~8% speed + + if A.size(1) > 4: + T = 2 * (A.size(1) // 2) + Aa = A[:, :T].view(A.size(0), T // 2, 2, -1, 1) + Xa = X[:, :T].view(X.size(0), T // 2, 2, -1, X.size(-1)) + Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0])) + Aa[:, :, 1].mul_(Aa[:, :, 0]) + PScan.expand_(Aa[:, :, 1], Xa[:, :, 1]) + Xa[:, 1:, 0].add_(Aa[:, 1:, 0].mul(Xa[:, :-1, 1])) + Aa[:, 1:, 0].mul_(Aa[:, :-1, 1]) + if T < A.size(1): + X[:, -1].add_(A[:, -1].mul(X[:, -2])) + A[:, -1].mul_(A[:, -2]) + elif A.size(1) == 2: + X[:, 1].add_(A[:, 1].mul(X[:, 0])) + A[:, 1].mul_(A[:, 0]) + elif A.size(1) == 3: + X[:, 1].add_(A[:, 1].mul(X[:, 0])) + A[:, 1].mul_(A[:, 0]) + X[:, 2].add_(A[:, 2].mul(X[:, 1])) + A[:, 2].mul_(A[:, 1]) + elif A.size(1) == 4: + X[:, 1].add_(A[:, 1].mul(X[:, 0])) + A[:, 1].mul_(A[:, 0]) + X[:, 2].add_(A[:, 2].mul(X[:, 1])) + A[:, 2].mul_(A[:, 1]) + X[:, 3].add_(A[:, 3].mul(X[:, 2])) + A[:, 3].mul_(A[:, 2]) @staticmethod def acc_rev_(A, X): - if X.size(1) == 1: - return - T = 2 * (X.size(1) // 2) - Aa = A[:, -T:].view(A.size(0), T // 2, 2, -1, 1) - Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1, X.size(-1)) - Xa[:, :, 0].add_(Aa[:, :, 1].mul(Xa[:, :, 1])) - B = Aa[:, :, 0].clone() - B[:, 1:].mul_(Aa[:, :-1, 1]) - PScan.acc_rev_(B, Xa[:, :, 0]) - Xa[:, :-1, 1].add_(Aa[:, 1:, 0].mul(Xa[:, 1:, 0])) - if T < A.size(1): + if A.size(1) > 4: + T = 2 * (X.size(1) // 2) + Aa = A[:, -T:].view(A.size(0), T // 2, 2, -1, 1) + Xa = X[:, -T:].view(X.size(0), T // 2, 2, -1, X.size(-1)) + Xa[:, :, 0].add_(Aa[:, :, 1].mul(Xa[:, :, 1])) + B = Aa[:, :, 0].clone() + B[:, 1:].mul_(Aa[:, :-1, 1]) + PScan.acc_rev_(B, Xa[:, :, 0]) + Xa[:, :-1, 1].add_(Aa[:, 1:, 0].mul(Xa[:, 1:, 0])) + if T < A.size(1): + X[:, 0].add_(A[:, 1].mul(X[:, 1])) + elif A.size(1) == 2: + X[:, 0].add_(A[:, 1].mul(X[:, 1])) + elif A.size(1) == 3: + X[:, 1].add_(A[:, 2].mul(X[:, 2])) + X[:, 0].add_(A[:, 1].mul(X[:, 1])) + elif A.size(1) == 4: + X[:, 2].add_(A[:, 3].mul(X[:, 3])) + X[:, 1].add_(A[:, 2].mul(X[:, 2])) X[:, 0].add_(A[:, 1].mul(X[:, 1])) # A is NxT, X is NxTxD, Y_init is NxD @@ -81,59 +105,80 @@ class PScan(torch.autograd.Function): pscan = PScan.apply + +def naive_pscan(A, X, Y_init): + y = Y_init + s = 0 + + for k in range(A.size(1)): + y = A[:, k, None] * y + X[:, k] + s = s + y + + s = s.sum() + + ###################################################################### if __name__ == "__main__": import time, sys - A = torch.rand(17, 12, 3) - X = torch.rand(17, 12, 3, 11) - Y_init = torch.rand(17, 3, 11) - Y = pscan(A, X, Y_init) - exit(0) + # A = torch.rand(17, 12, 3) + # X = torch.rand(17, 12, 3, 11) + # Y_init = torch.rand(17, 3, 11) + # Y = pscan(A, X, Y_init) - N, T, D = 2, 1047, 3 + # exit(0) - A = torch.rand(N, T, dtype=torch.float64).requires_grad_() - X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() - Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_() + err = 0 - # Iterative implementation + for _ in range(100): + N, T, D = 2, 112, 3 - y = Y_init - s = 0 + T = torch.randint(10, (1,)).item() + 1 - for k in range(A.size(1)): - y = A[:, k, None] * y + X[:, k] - s = s + y + A = 0.9 + 0.1 * torch.rand(N, T, dtype=torch.float64).requires_grad_() + X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() + Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_() - s = s.sum() + # Iterative implementation + + y = Y_init + s = 0 + + for k in range(A.size(1)): + y = A[:, k, None] * y + X[:, k] + s = s + y + + s = s.sum() - gA_ref, gX_ref, gY_init_ref = torch.autograd.grad( - s, (A, X, Y_init), retain_graph=True - ) + gA_ref, gX_ref, gY_init_ref = torch.autograd.grad( + s, (A, X, Y_init), retain_graph=True + ) - # parallel scan + # parallel scan - start_time = time.perf_counter() - for _ in range(1000): - Y = pscan(A, X, Y_init) - duration = time.perf_counter() - start_time - print(f"duration {duration}") + start_time = time.perf_counter() + for _ in range(1000): + Y = pscan(A, X, Y_init) + duration = time.perf_counter() - start_time + print(f"duration {duration}") - s = Y.sum() + s = Y.sum() - gA, gX, gY_init = torch.autograd.grad(s, (A, X, Y_init), retain_graph=True) + gA, gX, gY_init = torch.autograd.grad(s, (A, X, Y_init), retain_graph=True) - # print(gA) - # print(gX) - # print(gY_init) + err = max( + [ + err, + (gA - gA_ref).abs().max(), + (gX - gX_ref).abs().max(), + (gY_init - gY_init_ref).abs().max(), + ] + ) - print((gA - gA_ref).norm()) - print((gX - gX_ref).norm()) - print((gY_init - gY_init_ref).norm()) + # Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init) + # Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1]) - Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init) - Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1]) + # print((Y - torch.cat([Y1, Y2], dim=1)).abs().max()) - print((Y - torch.cat([Y1, Y2], dim=1)).norm()) +print(f"{err=}") -- 2.20.1