X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=099847c95d9404d477b069d8cdf78a62304b3784;hb=2434c00a82ebb0b23f45d891cc9f80324e3200bd;hp=633ad642c19a3045064ef858c0ee494a7c733425;hpb=6e87fe0cb8bd8a0042bbf7b2ede9d8ed0372fb6b;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index 633ad64..099847c 100755 --- a/mygpt.py +++ b/mygpt.py @@ -190,6 +190,8 @@ class DumbRec(nn.Module): nb_lines, attention_dropout=0.0, len_max=1e5, + logger=print, + **kwargs, ): super().__init__() @@ -319,6 +321,8 @@ class KVRec(nn.Module): nb_lines, attention_dropout=0.0, len_max=1e5, + logger=print, + **kwargs, ): super().__init__() @@ -471,6 +475,8 @@ class Caterpillar(nn.Module): caterpillar_height, attention_dropout=0.0, len_max=1e5, + logger=print, + **kwargs, ): super().__init__() @@ -487,12 +493,16 @@ class Caterpillar(nn.Module): self.proba_gate_dropout = 0.0 + default_bg = kwargs.get("default_bg") + if default_bg is None: + default_bg = -math.log(caterpillar_height - 1) + else: + default_bg = float(default_bg) + + logger(f"default_bg {default_bg}") + self.w_G = randw(nb_heads, caterpillar_height, dim_model) - self.b_G = nn.Parameter( - torch.full( - (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1) - ) - ) + self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg)) self.w_K = randw(nb_heads, dim_qk, dim_model) self.w_V = randw(nb_heads, dim_v, dim_model) @@ -565,24 +575,20 @@ class Caterpillar(nn.Module): torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None] ).sigmoid() - ###################################################################### - # Roll the gating indexes - - warnings.warn("rotating barrel", RuntimeWarning) + # Clip the gating to avoid values greater than 1 when several + # heads hit the same row - # print(f"SANITY2 {N=} {H=} {R=} {t0=} {t1=} {G.size()=}") + G = G / G.sum(1, keepdim=True).clamp(min=1) - n_barrel = torch.arange(N, device=G.device)[:, None, None, None] - h_barrel = torch.arange(H, device=G.device)[None, :, None, None] - r_barrel = torch.arange(R, device=G.device)[None, None, :, None] - t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :] - r_barrel = (r_barrel + (t_barrel + t0) // L) % R + ###################################################################### + # Roll the gating indexes - # GG = G.gather(dim=2,index=r_barrel) - G = G[n_barrel, h_barrel, r_barrel, t_barrel] + # warnings.warn("rotating barrel", RuntimeWarning) - # print("SANITY", (GG-G).abs()) - # exit(0) + # r_barrel = torch.arange(R, device=G.device)[None, None, :, None] + # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :] + # r_barrel = (r_barrel + (t_barrel + t0) // L) % R + # G = G.gather(dim=2, index=r_barrel.expand_as(G)) ###################################################################### # The "flashbacks" @@ -620,11 +626,6 @@ class Caterpillar(nn.Module): # We prepare the arguments for the parallel scan - # Clip the gating to avoid values greater than 1 when several - # heads hit the same row - - G = G / G.sum(1, keepdim=True).clamp(min=1) - A = 1 - G.sum(1) # warnings.warn("harmonic recurrence", RuntimeWarning) @@ -718,6 +719,8 @@ class QKVAttention(nn.Module): nb_heads=1, causal=False, attention_dropout=0.0, + logger=print, + **kwargs, ): super().__init__() @@ -809,6 +812,8 @@ class MyGPT(nn.Module): dropout=0.0, len_max=1e5, attention_layer="kvrec", + logger=print, + **kwargs, ): super().__init__() @@ -845,6 +850,8 @@ class MyGPT(nn.Module): nb_heads=nb_heads, causal=causal, attention_dropout=dropout, + logger=logger, + **kwargs, ) elif attention_layer == "dumbrec": return DumbRec( @@ -854,6 +861,8 @@ class MyGPT(nn.Module): nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, + logger=logger, + **kwargs, ) elif attention_layer == "kvrec": return KVRec( @@ -863,6 +872,8 @@ class MyGPT(nn.Module): nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, + logger=logger, + **kwargs, ) elif attention_layer == "caterpillar": return Caterpillar( @@ -873,6 +884,8 @@ class MyGPT(nn.Module): caterpillar_length=self.caterpillar_length, caterpillar_height=self.caterpillar_height, attention_dropout=dropout, + logger=logger, + **kwargs, ) else: raise ValueError(f"Unknown attention type {attention_layer}.")