X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=pscan.py;h=0bb0d145bf9c6c82115956c8ce1e6a063e56e747;hb=037adb139441f40078421cd40f6aad1748c2724d;hp=a14f009c97a3c180c6b9ba539be77d8d0eb511b6;hpb=3c6931f8ddc8160550e026d9e9610ef71260ce10;p=mygptrnn.git diff --git a/pscan.py b/pscan.py index a14f009..0bb0d14 100755 --- a/pscan.py +++ b/pscan.py @@ -122,6 +122,22 @@ def naive_pscan(A, X, Y_init): if __name__ == "__main__": import time, sys + ###################################################################### + + N, T, D = 16, 4096, 32 + + for r in range(timing.size(0)): + A = 0.9 + 0.1 * torch.rand(N, T, dtype=torch.float64).requires_grad_() + X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() + Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_() + + start_time = time.perf_counter() + for _ in range(1000): + Y = pscan(A, X, Y_init) + duration = time.perf_counter() - start_time + + ###################################################################### + # A = torch.rand(17, 12, 3) # X = torch.rand(17, 12, 3, 11) # Y_init = torch.rand(17, 3, 11) @@ -130,11 +146,12 @@ if __name__ == "__main__": # exit(0) err = 0 + timing = torch.empty(10) - for _ in range(100): - N, T, D = 2, 112, 3 + for r in range(timing.size(0)): + N, T, D = 2, 1120, 3 - T = torch.randint(10, (1,)).item() + 1 + # T = torch.randint(10, (1,)).item() + 1 A = 0.9 + 0.1 * torch.rand(N, T, dtype=torch.float64).requires_grad_() X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_() @@ -161,7 +178,9 @@ if __name__ == "__main__": for _ in range(1000): Y = pscan(A, X, Y_init) duration = time.perf_counter() - start_time + print(f"duration {duration}") + timing[r] = duration s = Y.sum() @@ -181,4 +200,4 @@ if __name__ == "__main__": # print((Y - torch.cat([Y1, Y2], dim=1)).abs().max()) - print(f"{err=}") + print(f"err={err:.2e} duration={timing.mean():.2e} (+/- {timing.std():.2e})")