From b34425c4708409c607c629f4f45a82c1b554974e Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Sat, 25 Mar 2023 21:02:38 +0100
Subject: [PATCH] Update.

---
 mygpt.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mygpt.py b/mygpt.py
index 7166788..06b56df 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -148,8 +148,8 @@ class QKVAttention(nn.Module):
 
         if amm_generator is None:
             self.amm_generator = (
-                lambda d: torch.arange(d)[None, None, :, None]
-                < torch.arange(d)[None, None, None, :]
+                lambda d: torch.arange(d)[:, None]
+                < torch.arange(d)[None, :]
             )
         else:
             self.amm_generator = amm_generator
@@ -190,7 +190,7 @@ class QKVAttention(nn.Module):
 
         if self.causal:
             if bs_q.first == 0:
-                self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)
+                self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)[None, None,:,:]
             a = a.masked_fill(
                 self.cache_attzero[
                     :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
-- 
2.39.5
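
Editor's note (not part of the patch): the change makes amm_generator return a
plain 2D boolean mask, and the forward pass now adds the two broadcast
dimensions itself when caching it, so custom mask generators only have to
produce a (T, T) tensor. A minimal sketch, assuming only torch is available
and using a hypothetical sequence length d, to check that the new 2D mask
expanded with [None, None, :, :] equals the 4D mask the old lambda produced:

    import torch

    d = 5  # hypothetical sequence length for the check

    # Mask built by the old lambda: shape (1, 1, d, d), True above the diagonal.
    old_mask = torch.arange(d)[None, None, :, None] < torch.arange(d)[None, None, None, :]

    # Mask built by the new lambda: shape (d, d), expanded the way the
    # forward pass now does it before caching.
    new_mask = (torch.arange(d)[:, None] < torch.arange(d)[None, :])[None, None, :, :]

    assert torch.equal(old_mask, new_mask)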