From 8fdce4736a05a37d0f8706148dd743bce123fe1b Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 19 Jan 2024 00:19:12 +0100 Subject: [PATCH] Update. --- mygpt.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/mygpt.py b/mygpt.py index 5451584..fb24b9a 100755 --- a/mygpt.py +++ b/mygpt.py @@ -644,17 +644,18 @@ class Caterpillar(nn.Module): warnings.warn("gate dropout", RuntimeWarning) - # kill = ( - # torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout - # ).float() - + # Pick a point in each of the NxHxR timeline and set this + # entry and the following to 1 kill = ( torch.rand(N, H, R, t1 - t0, device=G.device).sort(dim=3).indices == 0 ).cumsum(dim=3) + + # Keep these mask for only some of the NxHxR kill = kill * ( torch.rand(N, H, R, 1, device=G.device) <= self.proba_gate_dropout ) + # The coefficient to keep are the complementary mask = 1 - kill masked_next_V, masked_next_K = recurrence(G * mask, V, K) @@ -674,8 +675,8 @@ class Caterpillar(nn.Module): Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q) - # We build tensors NxHxTxFxL where N is the sample index, H - # the head, T the time, F the row in the caterpillar, and L + # We build tensors NxHxTxRxL where N is the sample index, H + # the head, T the time, R the row in the caterpillar, and L # the column in the caterpillar windowed_V = moving_window( @@ -689,7 +690,7 @@ class Caterpillar(nn.Module): # We have an attention score for each of the RxL values ar = torch.einsum( - "nhtd,nftld->nhtfl", + "nhtd,nrtld->nhtrl", Q, windowed_K, ) / math.sqrt(DK) -- 2.39.5