# A = har / (har + 1)
# G = G / har
+
+######################################################################
+
+2024 Jan 18 08:46:18 (from mygpt.py)
+
+ # warnings.warn("softmax gating", RuntimeWarning)
+
+ # G = (
+ # torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
+ # ).softmax(dim=2)
    G = (
        torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
    ).sigmoid()
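Note: the einsum contracts the channel dimension, so with X of shape
N x T x C and w_G of shape H x R x C, the gating G comes out as
N x H x R x T, one sigmoid gate per head, row, and time step. A minimal
shape-check sketch, with all sizes made up:

    import torch

    N, T, C, H, R = 2, 5, 16, 4, 3  # hypothetical sizes
    X = torch.randn(N, T, C)
    w_G, b_G = torch.randn(H, R, C), torch.randn(H, R)
    G = (torch.einsum("ntc,hrc->nhrt", X, w_G) + b_G[None, :, :, None]).sigmoid()
    assert G.shape == (N, H, R, T)
    assert ((G > 0) & (G < 1)).all()  # sigmoid keeps each gate in (0, 1)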
- # warnings.warn("softmax gating", RuntimeWarning)
+    # Normalize the gating so that its sum over the heads does not
+    # exceed 1 when several heads hit the same row
- # G = (
- # torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
- # ).softmax(dim=2)
+ G = G / G.sum(1, keepdim=True).clamp(min=1)
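This renormalization only kicks in when the head gates for a row sum to
more than 1; otherwise the clamp leaves G untouched. A small sketch of
the behavior, with made-up values:

    import torch

    # two heads hitting the same row at one time step (N x H x R x T = 1 x 2 x 1 x 1)
    G = torch.tensor([0.9, 0.8]).view(1, 2, 1, 1)
    G = G / G.sum(1, keepdim=True).clamp(min=1)
    assert torch.allclose(G.sum(1), torch.tensor(1.0))  # rescaled to sum to 1

    # head gates already summing to less than 1 pass through unchanged
    G = torch.tensor([0.3, 0.2]).view(1, 2, 1, 1)
    assert torch.equal(G, G / G.sum(1, keepdim=True).clamp(min=1))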
######################################################################
- # The "flashbacks"
-
- if self.training and self.proba_gate_dropout > 0.0:
- # This is a better implementation of "flashbacks".
-
- # G is NxHxExT where e is the caterpillar's row.
-
- warnings.warn("gate dropout", RuntimeWarning)
-
- kill = (
- torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
- ).float()
-
- alpha = G / (1 - self.proba_gate_dropout)
-
- G = alpha * (1 - kill)
def recurrence(G, V, K):
- # Clip the gating to avoid values greater than 1 when several
- # heads hit the same row
-
- G = G / G.sum(1, keepdim=True).clamp(min=1)
-
# We prepare the arguments for the parallel scan
A = 1 - G.sum(1)
next_V, next_K = recurrence(G, V, K)
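Since the clamp above caps the per-row sum of the head gates at 1, the
retention coefficient A = 1 - G.sum(1) prepared for the scan stays in
[0, 1] (up to float rounding), so each row of the recurrent state is
presumably updated as a convex combination of its previous value and the
gated inputs. A quick check, with made-up sizes:

    import torch

    G = torch.rand(2, 4, 3, 5)  # N x H x R x T, gates in [0, 1)
    G = G / G.sum(1, keepdim=True).clamp(min=1)
    A = 1 - G.sum(1)  # N x R x T
    assert (A >= -1e-6).all() and (A <= 1).all()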
+ if self.training and self.proba_gate_dropout > 0.0:
+ # G is NxHxRxT where r is the caterpillar's row.
+
+ warnings.warn("gate dropout", RuntimeWarning)
+
+ kill = (
+ torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
+ ).float()
+
+ mask = 1 - kill
+
+ masked_next_V, masked_next_K = recurrence(G * mask, V, K)
+
+ next_V = next_V.detach() + (masked_next_V - masked_next_V.detach()) / (
+ 1 - self.proba_gate_dropout
+ )
+ next_K = next_K.detach() + (masked_next_K - masked_next_K.detach()) / (
+ 1 - self.proba_gate_dropout
+ )
+
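The detach arithmetic makes the forward pass return the undropped
next_V / next_K while the backward pass sees only the masked recurrence,
rescaled by 1 / (1 - proba_gate_dropout), i.e. the gate dropout acts on
the gradients only. A minimal sketch of the trick on a stand-in function:

    import torch

    p = 0.25
    x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    mask = torch.tensor([1.0, 0.0, 1.0])  # made-up kill pattern

    full = x * 2           # stands in for recurrence(G, V, K)
    masked = x * mask * 2  # stands in for recurrence(G * mask, V, K)

    y = full.detach() + (masked - masked.detach()) / (1 - p)
    assert torch.equal(y, full)  # forward value is the undropped one
    y.sum().backward()
    print(x.grad)  # tensor([2.6667, 0.0000, 2.6667]): masked gradient / (1 - p)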
self.rec_V[:, :, t0:t1] = next_V
self.rec_K[:, :, t0:t1] = next_K