X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=4555b1e4de65517051d10e5249f0cba7cc496d35;hb=a113de0d0ba103b6fb1bfdec69b550147a2a262f;hp=06b56df54e4efa7aa114e561b07df8dc9cd4ba0c;hpb=b34425c4708409c607c629f4f45a82c1b554974e;p=beaver.git diff --git a/mygpt.py b/mygpt.py index 06b56df..4555b1e 100755 --- a/mygpt.py +++ b/mygpt.py @@ -148,8 +148,7 @@ class QKVAttention(nn.Module): if amm_generator is None: self.amm_generator = ( - lambda d: torch.arange(d)[:, None] - < torch.arange(d)[None, :] + lambda d: torch.arange(d)[:, None] < torch.arange(d)[None, :] ) else: self.amm_generator = amm_generator @@ -190,7 +189,9 @@ class QKVAttention(nn.Module): if self.causal: if bs_q.first == 0: - self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)[None, None,:,:] + self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)[ + None, None, :, : + ] a = a.masked_fill( self.cache_attzero[ :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb