Update: simplify the default amm_generator to return a plain 2D causal mask, and add the batch/head broadcast dimensions only when the mask is cached.
author    François Fleuret <francois@fleuret.org>
          Sat, 25 Mar 2023 20:02:38 +0000 (21:02 +0100)
committer François Fleuret <francois@fleuret.org>
          Sat, 25 Mar 2023 20:02:38 +0000 (21:02 +0100)
mygpt.py

index 7166788..06b56df 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -148,8 +148,8 @@ class QKVAttention(nn.Module):
 
         if amm_generator is None:
             self.amm_generator = (
-                lambda d: torch.arange(d)[None, None, :, None]
-                < torch.arange(d)[None, None, None, :]
+                lambda d: torch.arange(d)[:, None]
+                < torch.arange(d)[None, :]
             )
         else:
             self.amm_generator = amm_generator
@@ -190,7 +190,7 @@ class QKVAttention(nn.Module):
 
         if self.causal:
             if bs_q.first == 0:
-                self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)
+                self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)[None, None, :, :]
             a = a.masked_fill(
                 self.cache_attzero[
                     :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb