X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=67c5cfd96ff11a5fd04b88eb6b26a72c90e97ddb;hb=6ca2c05c7470e92cd591fe1b8de33c80c1b27180;hp=9a02bcd13c0cd2dbd9c02df33699fb3b07a4ec87;hpb=318946b9a800dcc07531053e345bda46440f617f;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index 9a02bcd..67c5cfd 100755 --- a/mygpt.py +++ b/mygpt.py @@ -21,7 +21,7 @@ from torch.nn import functional as F import ffutils -from blanket import blanket +# from blanket import blanket # import memload @@ -569,7 +569,7 @@ class Caterpillar(nn.Module): V = torch.einsum("ntc,hdc->nhtd", X, self.w_V) K = torch.einsum("ntc,hdc->nhtd", X, self.w_K) - V, K = blanket(V), blanket(K) + # V, K = blanket(V), blanket(K) ###################################################################### # Compute the recurrent state @@ -673,7 +673,7 @@ class Caterpillar(nn.Module): Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q) - Q = blanket(Q) + # Q = blanket(Q) # We build tensors NxHxTxRxL where N is the sample index, H # the head, T the time, R the row in the caterpillar, and L @@ -712,7 +712,7 @@ class Caterpillar(nn.Module): # Compute the final output - Y = blanket(Y) + # Y = blanket(Y) self.cache_Y[:, t0:t1] = Y @ self.w_O