X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=099847c95d9404d477b069d8cdf78a62304b3784;hb=2434c00a82ebb0b23f45d891cc9f80324e3200bd;hp=9d3abb62cc8b6d95ce8be5b291d1c9e36e7f100d;hpb=cebc20b3608a41bfd27b2ab9d950c082f9b7ea89;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index 9d3abb6..099847c 100755 --- a/mygpt.py +++ b/mygpt.py @@ -190,6 +190,8 @@ class DumbRec(nn.Module): nb_lines, attention_dropout=0.0, len_max=1e5, + logger=print, + **kwargs, ): super().__init__() @@ -319,6 +321,8 @@ class KVRec(nn.Module): nb_lines, attention_dropout=0.0, len_max=1e5, + logger=print, + **kwargs, ): super().__init__() @@ -471,6 +475,8 @@ class Caterpillar(nn.Module): caterpillar_height, attention_dropout=0.0, len_max=1e5, + logger=print, + **kwargs, ): super().__init__() @@ -487,12 +493,16 @@ class Caterpillar(nn.Module): self.proba_gate_dropout = 0.0 + default_bg = kwargs.get("default_bg") + if default_bg is None: + default_bg = -math.log(caterpillar_height - 1) + else: + default_bg = float(default_bg) + + logger(f"default_bg {default_bg}") + self.w_G = randw(nb_heads, caterpillar_height, dim_model) - self.b_G = nn.Parameter( - torch.full( - (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1) - ) - ) + self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg)) self.w_K = randw(nb_heads, dim_qk, dim_model) self.w_V = randw(nb_heads, dim_v, dim_model) @@ -565,21 +575,20 @@ class Caterpillar(nn.Module): torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None] ).sigmoid() - ###################################################################### - # Roll the gating indexes + # Clip the gating to avoid values greater than 1 when several + # heads hit the same row - warnings.warn("rotating barrel", RuntimeWarning) - n_barrel = torch.arange(N, device=G.device)[:, None, None, None] - h_barrel = torch.arange(H, device=G.device)[None, :, None, None] - r_barrel = torch.arange(R, device=G.device)[None, None, :, None] - t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :] - r_barrel = (r_barrel + t_barrel + t0) % R + G = G / G.sum(1, keepdim=True).clamp(min=1) - # print(f"({N}, {H}, {R}, {t1-t0}) {G.size()=}") + ###################################################################### + # Roll the gating indexes - G = G[n_barrel, h_barrel, r_barrel, t_barrel] + # warnings.warn("rotating barrel", RuntimeWarning) - # print(G.sum()) + # r_barrel = torch.arange(R, device=G.device)[None, None, :, None] + # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :] + # r_barrel = (r_barrel + (t_barrel + t0) // L) % R + # G = G.gather(dim=2, index=r_barrel.expand_as(G)) ###################################################################### # The "flashbacks" @@ -617,11 +626,6 @@ class Caterpillar(nn.Module): # We prepare the arguments for the parallel scan - # Clip the gating to avoid values greater than 1 when several - # heads hit the same row - - G = G / G.sum(1, keepdim=True).clamp(min=1) - A = 1 - G.sum(1) # warnings.warn("harmonic recurrence", RuntimeWarning) @@ -715,6 +719,8 @@ class QKVAttention(nn.Module): nb_heads=1, causal=False, attention_dropout=0.0, + logger=print, + **kwargs, ): super().__init__() @@ -806,6 +812,8 @@ class MyGPT(nn.Module): dropout=0.0, len_max=1e5, attention_layer="kvrec", + logger=print, + **kwargs, ): super().__init__() @@ -842,6 +850,8 @@ class MyGPT(nn.Module): nb_heads=nb_heads, causal=causal, attention_dropout=dropout, + logger=logger, + **kwargs, ) elif attention_layer == "dumbrec": return DumbRec( @@ -851,6 +861,8 @@ class MyGPT(nn.Module): nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, + logger=logger, + **kwargs, ) elif attention_layer == "kvrec": return KVRec( @@ -860,6 +872,8 @@ class MyGPT(nn.Module): nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, + logger=logger, + **kwargs, ) elif attention_layer == "caterpillar": return Caterpillar( @@ -870,6 +884,8 @@ class MyGPT(nn.Module): caterpillar_length=self.caterpillar_length, caterpillar_height=self.caterpillar_height, attention_dropout=dropout, + logger=logger, + **kwargs, ) else: raise ValueError(f"Unknown attention type {attention_layer}.")