X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=mygpt.py;h=7c9991f7d56fc3069a261b0642e4381c55bd02d9;hb=73acbc986f9c386c001117581c4fc72d2f36803a;hp=a62cf4908ba88622a3f567d082c7a94711887fde;hpb=4f5d03d3371b124121e8f9fc0ff583553fea1e38;p=mygptrnn.git

diff --git a/mygpt.py b/mygpt.py
index a62cf49..7c9991f 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -190,6 +190,8 @@ class DumbRec(nn.Module):
         nb_lines,
         attention_dropout=0.0,
         len_max=1e5,
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
 
@@ -319,6 +321,8 @@ class KVRec(nn.Module):
         nb_lines,
         attention_dropout=0.0,
         len_max=1e5,
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
 
@@ -471,6 +475,8 @@ class Caterpillar(nn.Module):
         caterpillar_height,
         attention_dropout=0.0,
         len_max=1e5,
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
 
@@ -487,12 +493,14 @@ class Caterpillar(nn.Module):
 
         self.proba_gate_dropout = 0.0
 
+        default_b_G = kwargs.get("default_b_G")
+        if default_b_G is None:
+            default_b_G = -math.log(caterpillar_height - 1)
+
+        logger(f"default_b_G {default_b_G}")
+
         self.w_G = randw(nb_heads, caterpillar_height, dim_model)
-        self.b_G = nn.Parameter(
-            torch.full(
-                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
-            )
-        )
+        self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_b_G))
 
         self.w_K = randw(nb_heads, dim_qk, dim_model)
         self.w_V = randw(nb_heads, dim_v, dim_model)
@@ -565,15 +573,20 @@ class Caterpillar(nn.Module):
             torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
         ).sigmoid()
 
+        # Clip the gating to avoid values greater than 1 when several
+        # heads hit the same row
+
+        G = G / G.sum(1, keepdim=True).clamp(min=1)
+
         ######################################################################
         # Roll the gating indexes
 
-        warnings.warn("rotating barrel", RuntimeWarning)
+        # warnings.warn("rotating barrel", RuntimeWarning)
 
-        r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
-        t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
-        r_barrel = (r_barrel + (t_barrel + t0) // L) % R
-        G = G.gather(dim=2, index=r_barrel.expand_as(G))
+        # r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
+        # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
+        # r_barrel = (r_barrel + (t_barrel + t0) // L) % R
+        # G = G.gather(dim=2, index=r_barrel.expand_as(G))
 
         ######################################################################
         # The "flashbacks"
@@ -611,11 +624,6 @@ class Caterpillar(nn.Module):
 
         # We prepare the arguments for the parallel scan
 
-        # Clip the gating to avoid values greater than 1 when several
-        # heads hit the same row
-
-        G = G / G.sum(1, keepdim=True).clamp(min=1)
-
         A = 1 - G.sum(1)
 
         # warnings.warn("harmonic recurrence", RuntimeWarning)
@@ -709,6 +717,8 @@ class QKVAttention(nn.Module):
         nb_heads=1,
         causal=False,
         attention_dropout=0.0,
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
 
@@ -800,6 +810,8 @@ class MyGPT(nn.Module):
         dropout=0.0,
         len_max=1e5,
         attention_layer="kvrec",
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
 
@@ -836,6 +848,8 @@ class MyGPT(nn.Module):
                 nb_heads=nb_heads,
                 causal=causal,
                 attention_dropout=dropout,
+                logger=logger,
+                **kwargs,
             )
         elif attention_layer == "dumbrec":
             return DumbRec(
@@ -845,6 +859,8 @@ class MyGPT(nn.Module):
                 nb_heads=nb_heads,
                 nb_lines=nb_lines,
                 attention_dropout=dropout,
+                logger=logger,
+                **kwargs,
             )
         elif attention_layer == "kvrec":
             return KVRec(
@@ -854,6 +870,8 @@ class MyGPT(nn.Module):
                 nb_heads=nb_heads,
                 nb_lines=nb_lines,
                 attention_dropout=dropout,
+                logger=logger,
+                **kwargs,
             )
         elif attention_layer == "caterpillar":
             return Caterpillar(
@@ -864,6 +882,8 @@ class MyGPT(nn.Module):
                 caterpillar_length=self.caterpillar_length,
                 caterpillar_height=self.caterpillar_height,
                 attention_dropout=dropout,
+                logger=logger,
+                **kwargs,
             )
         else:
            raise ValueError(f"Unknown attention type {attention_layer}.")
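
Commentary on the commit (a sketch, not part of the diff): the change threads a logger callable and **kwargs through every attention-layer constructor so that extra keywords given to MyGPT reach the selected layer, makes the Caterpillar gate bias configurable through a default_b_G keyword, comments out the experimental "rotating barrel" index rolling, and applies the multi-head gate clipping right after the sigmoid instead of just before the parallel scan. The standalone Python sketch below uses assumed tensor shapes; only the two formulas themselves come from the diff.

    import math
    import torch

    N, H, R, T = 2, 4, 8, 16  # assumed: batch, heads, caterpillar_height, time steps

    # Default gate bias: with b_G = -log(R - 1), sigmoid(b_G) = 1 / R, so at
    # initialization each head writes to a given row with a gate of about 1/R.
    default_b_G = -math.log(R - 1)
    assert abs(torch.sigmoid(torch.tensor(default_b_G)).item() - 1 / R) < 1e-6

    # Gate clipping, now applied immediately after the sigmoid: where the H
    # heads together exceed a total gate of 1 on a row, rescale them so the
    # sum becomes exactly 1; rows already summing to <= 1 are left unchanged,
    # since clamp(min=1) makes their divisor 1.
    G = torch.rand(N, H, R, T)
    G = G / G.sum(1, keepdim=True).clamp(min=1)
    assert (G.sum(1) <= 1 + 1e-6).all()

Since the clipping guarantees G.sum(1) <= 1, the downstream recurrence coefficient A = 1 - G.sum(1) stays in [0, 1]. A hypothetical call such as MyGPT(..., attention_layer="caterpillar", logger=my_log_fn, default_b_G=-2.0) would forward default_b_G through **kwargs to Caterpillar and report the chosen value through my_log_fn (my_log_fn and the elided arguments are illustrative, not from the source).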