From e3d5af800ccd197580265709c4499bf281beecb8 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 18 Jan 2024 08:54:04 +0100 Subject: [PATCH] Update. --- fridge | 10 ++++++++++ mygpt.py | 48 +++++++++++++++++++++++------------------------- 2 files changed, 33 insertions(+), 25 deletions(-) diff --git a/fridge b/fridge index d09e92d..2cc6d01 100644 --- a/fridge +++ b/fridge @@ -292,3 +292,13 @@ class Calibrator: # A = har / (har + 1) # G = G / har + +###################################################################### + +2024 Jan 18 08:46:18 (from mygpt.py) + + # warnings.warn("softmax gating", RuntimeWarning) + + # G = ( + # torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None] + # ).softmax(dim=2) diff --git a/mygpt.py b/mygpt.py index a27b99e..492a9bb 100755 --- a/mygpt.py +++ b/mygpt.py @@ -597,36 +597,14 @@ class Caterpillar(nn.Module): torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None] ).sigmoid() - # warnings.warn("softmax gating", RuntimeWarning) + # Clip the gating to avoid values greater than 1 when several + # heads hit the same row - # G = ( - # torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None] - # ).softmax(dim=2) + G = G / G.sum(1, keepdim=True).clamp(min=1) ###################################################################### - # The "flashbacks" - - if self.training and self.proba_gate_dropout > 0.0: - # This is a better implementation of "flashbacks". - - # G is NxHxExT where e is the caterpillar's row. - - warnings.warn("gate dropout", RuntimeWarning) - - kill = ( - torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout - ).float() - - alpha = G / (1 - self.proba_gate_dropout) - - G = alpha * (1 - kill) def recurrence(G, V, K): - # Clip the gating to avoid values greater than 1 when several - # heads hit the same row - - G = G / G.sum(1, keepdim=True).clamp(min=1) - # We prepare the arguments for the parallel scan A = 1 - G.sum(1) @@ -663,6 +641,26 @@ class Caterpillar(nn.Module): next_V, next_K = recurrence(G, V, K) + if self.training and self.proba_gate_dropout > 0.0: + # G is NxHxRxT where r is the caterpillar's row. + + warnings.warn("gate dropout", RuntimeWarning) + + kill = ( + torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout + ).float() + + mask = 1 - kill + + masked_next_V, masked_next_K = recurrence(G * mask, V, K) + + next_V = next_V.detach() + (masked_next_V - masked_next_V.detach()) / ( + 1 - self.proba_gate_dropout + ) + next_K = next_K.detach() + (masked_next_K - masked_next_K.detach()) / ( + 1 - self.proba_gate_dropout + ) + self.rec_V[:, :, t0:t1] = next_V self.rec_K[:, :, t0:t1] = next_K -- 2.20.1