Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 16 Dec 2023 13:54:40 +0000 (07:54 -0600)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 16 Dec 2023 13:54:40 +0000 (07:54 -0600)
pscan.py [new file with mode: 0755]

diff --git a/pscan.py b/pscan.py
new file mode 100755 (executable)
index 0000000..36490ff
--- /dev/null
+++ b/pscan.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+
+def naive_rec(A, X, Y0):
+    Y = []
+    for t in range(X.size(1)):
+        if t == 0:
+            Y.append(A[:, t] * Y0 + X[:, t])
+        else:
+            Y.append(A[:, t] * Y[-1] + X[:, t])
+
+    return torch.cat([y[:, None, :] for y in Y], dim=1)
+
+
+######################################################################
+
+# A is NxTx1 and X is NxTxD
+#
+# Returns Y defined with
+#
+#           Y[:, 0] = A[:, 0] * Y0 + X[:,0]
+# for t > 0 Y[:, t] = A[:, t] * Y[:, t - 1] + X[:, t]
+
+
+def pscan_rec(A, X, Y0):
+    if X.size(1) % 2 == 1:
+        if X.size(1) == 1:
+            return A[:, :1] * Y0[:, None] + X[:, :1]
+        else:
+            Y = pscan_rec(A[:, :-1], X[:, :-1], Y0)
+            return torch.cat([Y, A[:, -1:] * Y[:, -1:] + X[:, -1:]], dim=1)
+
+    A2 = A.reshape(A.size(0), A.size(1) // 2, 2, A.size(2))
+    X2 = X.reshape(X.size(0), X.size(1) // 2, 2, X.size(2))
+
+    X_star = X2[:, :, 0].clone()
+    X_star[:, 1:] += A2[:, 1:, 0] * X2[:, :-1, 1]
+
+    A_star = A2[:, :, 0].clone()
+    A_star[:, 1:] *= A2[:, :-1, 1]
+
+    Y_star = pscan_rec(A_star, X_star, Y0)[:, :, None]
+
+    Y = torch.cat([Y_star, A2[:, :, 1, None] * Y_star + X2[:, :, 1, None]], dim=2)
+
+    Y = Y.reshape(Y.size(0), -1, Y.size(-1))
+
+    return Y
+
+
+######################################################################
+
+N, T, D = 5, 29, 12
+
+A = torch.rand(N, T, 1, dtype=torch.float64)
+X = torch.randint(10, (N, T, D), dtype=torch.float64)
+Y0 = torch.randint(10, (N, D), dtype=torch.float64)
+
+naive_Y = naive_rec(A, X, Y0)
+
+pscan_Y = pscan_rec(A, X, Y0)
+
+print((naive_Y - pscan_Y).pow(2).mean())
+
+pscan_Y1 = pscan_rec(A[:, :15], X[:, :15], Y0)
+pscan_Y2 = pscan_rec(A[:, 15:], X[:, 15:], pscan_Y1[:, -1])
+
+print((naive_Y - torch.cat([pscan_Y1, pscan_Y2], dim=1)).pow(2).mean())