X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=4555b1e4de65517051d10e5249f0cba7cc496d35;hb=49bfa9f885bb100f0a262dcfd4ce7d10f75319d0;hp=311ff6bf4dd39c35f1d9182c8db246f1f375f5f1;hpb=2cd3f15987d2bf9050f737cd13506740ad3e90cb;p=beaver.git

diff --git a/mygpt.py b/mygpt.py
index 311ff6b..4555b1e 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -106,20 +106,16 @@ class AddPositionalEncoding(nn.Module):
             )
 
             order_output = order + 1
-            order_input = torch.cat(
-                (order.new_zeros(order.size(0), 1), order[:, :-1] + 1), 1
-            )
+            order_input = F.pad(order + 1, (1, -1))
 
-            self.pe = torch.cat(
-                (
-                    pe.gather(1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))),
-                    pe.gather(
-                        1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
-                    ),
-                ),
-                2,
+            pe_input = pe.gather(
+                1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))
+            )
+            pe_output = pe.gather(
+                1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
             )
+            self.pe = torch.cat((pe_input, pe_output), 2)
 
             self.cache_y = bs.x.new(bs.x.size())
 
         self.cache_y[:, bs.first : bs.first + bs.nb] = (
@@ -136,13 +132,27 @@ class AddPositionalEncoding(nn.Module):
 
 class QKVAttention(nn.Module):
     def __init__(
-        self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+        self,
+        dim_in,
+        dim_qk,
+        dim_v,
+        nb_heads=1,
+        causal=False,
+        attention_dropout=0.0,
+        amm_generator=None,
     ):
         super().__init__()
 
         def randw(*d):
             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
 
+        if amm_generator is None:
+            self.amm_generator = (
+                lambda d: torch.arange(d)[:, None] < torch.arange(d)[None, :]
+            )
+        else:
+            self.amm_generator = amm_generator
+
         self.causal = causal
         self.attention_dropout = attention_dropout
 
@@ -179,10 +189,9 @@ class QKVAttention(nn.Module):
 
         if self.causal:
             if bs_q.first == 0:
-                self.cache_attzero = (
-                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
-                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
-                )
+                self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)[
+                    None, None, :, :
+                ]
             a = a.masked_fill(
                 self.cache_attzero[
                     :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
@@ -219,6 +228,7 @@ class MyGPT(nn.Module):
         causal=False,
         dropout=0.0,
         len_max=1e5,
+        amm_generator=None,
     ):
         super().__init__()
 
@@ -242,6 +252,7 @@ class MyGPT(nn.Module):
                     nb_heads=nb_heads,
                     causal=causal,
                     attention_dropout=dropout,
+                    amm_generator=amm_generator,
                 ),
             ),
             WithResidual(
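
Usage note (a minimal sketch, not part of the diff or the repository): the change threads an amm_generator argument through QKVAttention and MyGPT. It is a callable that receives the sequence length and returns a boolean mask matrix used when causal=True; entries that are True are the attention coefficients that get suppressed, and the default lambda reproduces the standard causal mask. The sketch below assumes only the signatures visible in the diff; the local_causal_amm helper and the dimensions are illustrative assumptions.

import torch
import mygpt

# Hypothetical mask generator (illustrative, not from the repository).
# Same convention as the default in QKVAttention.__init__: it receives the
# sequence length d and returns a d x d boolean matrix whose True entries
# are masked out. Here: causal masking restricted to a window of 4 past
# positions.
def local_causal_amm(d, window=4):
    t = torch.arange(d)[:, None]
    s = torch.arange(d)[None, :]
    return (t < s) | (t - s >= window)

# The generated mask is only applied when causal=True; dimensions are
# illustrative.
attention = mygpt.QKVAttention(
    dim_in=64,
    dim_qk=16,
    dim_v=16,
    nb_heads=4,
    causal=True,
    attention_dropout=0.0,
    amm_generator=local_causal_amm,
)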