Update
[beaver.git] / mygpt.py
index 06b56df..4555b1e 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -148,8 +148,7 @@ class QKVAttention(nn.Module):
 
         if amm_generator is None:
             self.amm_generator = (
-                lambda d: torch.arange(d)[:, None]
-                < torch.arange(d)[None, :]
+                lambda d: torch.arange(d)[:, None] < torch.arange(d)[None, :]
             )
         else:
             self.amm_generator = amm_generator
@@ -190,7 +189,9 @@ class QKVAttention(nn.Module):
 
         if self.causal:
             if bs_q.first == 0:
-                self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)[None, None,:,:]
+                self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)[
+                    None, None, :, :
+                ]
             a = a.masked_fill(
                 self.cache_attzero[
                     :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb