From: François Fleuret Date: Sat, 25 Mar 2023 20:02:38 +0000 (+0100) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=beaver.git;a=commitdiff_plain;h=b34425c4708409c607c629f4f45a82c1b554974e Update. --- diff --git a/mygpt.py b/mygpt.py index 7166788..06b56df 100755 --- a/mygpt.py +++ b/mygpt.py @@ -148,8 +148,8 @@ class QKVAttention(nn.Module): if amm_generator is None: self.amm_generator = ( - lambda d: torch.arange(d)[None, None, :, None] - < torch.arange(d)[None, None, None, :] + lambda d: torch.arange(d)[:, None] + < torch.arange(d)[None, :] ) else: self.amm_generator = amm_generator @@ -190,7 +190,7 @@ class QKVAttention(nn.Module): if self.causal: if bs_q.first == 0: - self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device) + self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)[None, None,:,:] a = a.masked_fill( self.cache_attzero[ :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb