+ ######################################################################
+
+ N, T, D = 16, 4096, 32
+
+ for r in range(timing.size(0)):
+ A = 0.9 + 0.1 * torch.rand(N, T, dtype=torch.float64).requires_grad_()
+ X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
+ Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()
+
+ start_time = time.perf_counter()
+ for _ in range(1000):
+ Y = pscan(A, X, Y_init)
+ duration = time.perf_counter() - start_time
+
+ ######################################################################
+