X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=pscan.py;h=891ec5578e118f48d8f5e933d3f4603e1703e9b1;hb=refs%2Fheads%2Fmaster;hp=3526c31ed601270bb507f2c7771284ad455b5c62;hpb=674dd7c7adde6b4a9aaa5afd57dbe1d063a47fcc;p=pytorch.git diff --git a/pscan.py b/pscan.py index 3526c31..891ec55 100755 --- a/pscan.py +++ b/pscan.py @@ -55,8 +55,8 @@ class PScan(torch.autograd.Function): # # returns Y of same shape as X, with # - # Y[:,t] = A[:,0] * Y_init + X[:,0] if t == 0 - # = A[:,t] * Y[:,t-1] + X[:,t] otherwise + # Y[:, t] = A[:, 0] * Y_init + X[:, 0] if t == 0 + # = A[:, t] * Y[:, t-1] + X[:, t] otherwise @staticmethod def forward(ctx, A, X, Y_init): @@ -114,3 +114,8 @@ if __name__ == "__main__": print((gA - gA_ref).norm()) print((gX - gX_ref).norm()) print((gY_init - gY_init_ref).norm()) + + Y1 = pscan(A[:, : T // 2], X[:, : T // 2], Y_init) + Y2 = pscan(A[:, T // 2 :], X[:, T // 2 :], Y1[:, -1]) + + print((Y - torch.cat([Y1, Y2], dim=1)).norm())