X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=3a48cdbb793160ea9c88875d4f353b6a89555477;hb=3dd98b99909b2bca323673263874e2abb39ac10c;hp=9d3abb62cc8b6d95ce8be5b291d1c9e36e7f100d;hpb=cebc20b3608a41bfd27b2ab9d950c082f9b7ea89;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index 9d3abb6..3a48cdb 100755 --- a/mygpt.py +++ b/mygpt.py @@ -190,6 +190,8 @@ class DumbRec(nn.Module): nb_lines, attention_dropout=0.0, len_max=1e5, + logger=print, + **kwargs, ): super().__init__() @@ -319,6 +321,8 @@ class KVRec(nn.Module): nb_lines, attention_dropout=0.0, len_max=1e5, + logger=print, + **kwargs, ): super().__init__() @@ -471,6 +475,8 @@ class Caterpillar(nn.Module): caterpillar_height, attention_dropout=0.0, len_max=1e5, + logger=print, + **kwargs, ): super().__init__() @@ -485,14 +491,29 @@ class Caterpillar(nn.Module): self.caterpillar_height = caterpillar_height self.attention_dropout = attention_dropout - self.proba_gate_dropout = 0.0 + ###################################################################### + # sup_args + + x = kwargs.get("gate_dropout") + if x is None: + self.proba_gate_dropout = 0.0 + else: + self.proba_gate_dropout = float(x) + + logger(f"self.proba_gate_dropout {self.proba_gate_dropout}") + + x = kwargs.get("default_bg") + if x is None: + default_bg = -math.log(caterpillar_height - 1) + else: + default_bg = float(x) + + logger(f"default_bg {default_bg}") + + ###################################################################### self.w_G = randw(nb_heads, caterpillar_height, dim_model) - self.b_G = nn.Parameter( - torch.full( - (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1) - ) - ) + self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg)) self.w_K = randw(nb_heads, dim_qk, dim_model) self.w_V = randw(nb_heads, dim_v, dim_model) @@ -565,21 +586,11 @@ class Caterpillar(nn.Module): torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None] ).sigmoid() - ###################################################################### - # Roll the gating indexes - - warnings.warn("rotating barrel", RuntimeWarning) - n_barrel = torch.arange(N, device=G.device)[:, None, None, None] - h_barrel = torch.arange(H, device=G.device)[None, :, None, None] - r_barrel = torch.arange(R, device=G.device)[None, None, :, None] - t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :] - r_barrel = (r_barrel + t_barrel + t0) % R + # warnings.warn("softmax gating", RuntimeWarning) - # print(f"({N}, {H}, {R}, {t1-t0}) {G.size()=}") - - G = G[n_barrel, h_barrel, r_barrel, t_barrel] - - # print(G.sum()) + # G = ( + # torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None] + # ).softmax(dim=2) ###################################################################### # The "flashbacks" @@ -590,37 +601,32 @@ class Caterpillar(nn.Module): # G is NxHxExT where e is the caterpillar's row. warnings.warn("gate dropout", RuntimeWarning) - epsilon = 0.5 - dropout_head = ( - (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0) - .expand_as(G) - .float() - ) + kill = ( + torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout + ).float() - dropout_tail = dropout_head.cumsum(dim=3) - dropout_head + alpha = G / (1 - self.proba_gate_dropout) - dropout_active = ( - torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout - ).long() + G = alpha * (1 - kill) - dropout_head *= dropout_active - dropout_tail *= dropout_active + ###################################################################### + # Clip the gating to avoid values greater than 1 when several + # heads hit the same row - G = ( - G - + dropout_head * (1 - epsilon - G.detach()) - - dropout_tail * G.detach() - ) + G = G / G.sum(1, keepdim=True).clamp(min=1) ###################################################################### + # Roll the gating indexes - # We prepare the arguments for the parallel scan + # warnings.warn("rotating barrel", RuntimeWarning) - # Clip the gating to avoid values greater than 1 when several - # heads hit the same row + # r_barrel = torch.arange(R, device=G.device)[None, None, :, None] + # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :] + # r_barrel = (r_barrel + (t_barrel + t0) // L) % R + # G = G.gather(dim=2, index=r_barrel.expand_as(G)) - G = G / G.sum(1, keepdim=True).clamp(min=1) + # We prepare the arguments for the parallel scan A = 1 - G.sum(1) @@ -715,6 +721,8 @@ class QKVAttention(nn.Module): nb_heads=1, causal=False, attention_dropout=0.0, + logger=print, + **kwargs, ): super().__init__() @@ -806,6 +814,8 @@ class MyGPT(nn.Module): dropout=0.0, len_max=1e5, attention_layer="kvrec", + logger=print, + **kwargs, ): super().__init__() @@ -842,6 +852,8 @@ class MyGPT(nn.Module): nb_heads=nb_heads, causal=causal, attention_dropout=dropout, + logger=logger, + **kwargs, ) elif attention_layer == "dumbrec": return DumbRec( @@ -851,6 +863,8 @@ class MyGPT(nn.Module): nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, + logger=logger, + **kwargs, ) elif attention_layer == "kvrec": return KVRec( @@ -860,6 +874,8 @@ class MyGPT(nn.Module): nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, + logger=logger, + **kwargs, ) elif attention_layer == "caterpillar": return Caterpillar( @@ -870,6 +886,8 @@ class MyGPT(nn.Module): caterpillar_length=self.caterpillar_length, caterpillar_height=self.caterpillar_height, attention_dropout=dropout, + logger=logger, + **kwargs, ) else: raise ValueError(f"Unknown attention type {attention_layer}.")