Update.
author     François Fleuret <francois@fleuret.org>
           Thu, 8 Aug 2024 07:27:34 +0000 (09:27 +0200)
committer  François Fleuret <francois@fleuret.org>
           Thu, 8 Aug 2024 07:27:34 +0000 (09:27 +0200)
main.py
mygpt.py

diff --git a/main.py b/main.py
index 86eafea..c77a7f3 100755
--- a/main.py
+++ b/main.py
@@ -990,6 +990,11 @@ def train_complexifier(model_gen, model_pred1, model_pred2):
 
 models = []
 
+
+def compute_causal_attzero(t_q, t_k):
+    return t_q < t_k
+
+
 for k in range(args.nb_gpts):
     log_string(f"creating model {k} and its w_quizzes")
 
@@ -1000,7 +1005,7 @@ for k in range(args.nb_gpts):
         dim_hidden=args.dim_hidden,
         nb_heads=args.nb_heads,
         nb_blocks=args.nb_blocks,
-        causal=True,
+        compute_attzero=compute_causal_attzero,
         dropout=args.dropout,
     ).to(main_device)
 
@@ -1144,7 +1149,7 @@ if args.test == "generator":
         dim_hidden=args.dim_hidden,
         nb_heads=args.nb_heads,
         nb_blocks=args.nb_blocks,
-        causal=True,
+        compute_attzero=compute_causal_attzero,
         dropout=args.dropout,
     ).to(main_device)
 
diff --git a/mygpt.py b/mygpt.py
index 15ed80e..2706143 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -145,7 +145,7 @@ class QKVAttention(nn.Module):
         dim_qk,
         dim_v,
         nb_heads=1,
-        causal=False,
+        compute_attzero=None,
         attention_dropout=0.0,
     ):
         super().__init__()
@@ -153,7 +153,7 @@ class QKVAttention(nn.Module):
         def randw(*d):
             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
 
-        self.causal = causal
+        self.compute_attzero = compute_attzero
         self.attention_dropout = attention_dropout
         self.record_attention = False
 
@@ -165,10 +165,6 @@ class QKVAttention(nn.Module):
     def forward(self, bs_q):
         x_q = bs_q.x
 
-        assert (
-            self.causal or bs_q.complete()
-        ), "Partial evaluation is only possible for causal models"
-
         if bs_q.first == 0:
             self.cache_k = x_q.new_zeros(
                 x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
@@ -193,12 +189,12 @@ class QKVAttention(nn.Module):
             "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
         ) / math.sqrt(self.w_q.size(1))
 
-        if self.causal:
+        if self.compute_attzero is not None:
             if bs_q.first == 0:
-                self.cache_attzero = (
-                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
-                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
-                )
+                self.cache_attzero = self.compute_attzero(
+                    torch.arange(x_q.size(1), device=q.device)[:, None],
+                    torch.arange(x_q.size(1), device=q.device)[None, :],
+                )[None, None, :, :]
             a = a.masked_fill(
                 self.cache_attzero[
                     :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
@@ -251,7 +247,7 @@ class MyGPT(nn.Module):
         dim_hidden,
         nb_heads,
         nb_blocks,
-        causal=False,
+        compute_attzero=None,
         autoencoder_dim=-1,
         dropout=0.0,
         len_max=1e5,
@@ -281,7 +277,7 @@ class MyGPT(nn.Module):
                         dim_qk=dim_keys,
                         dim_v=dim_model // nb_heads,
                         nb_heads=nb_heads,
-                        causal=causal,
+                        compute_attzero=compute_attzero,
                         attention_dropout=dropout,
                     ),
                 ),
@@ -407,7 +403,6 @@ if __name__ == "__main__":
         nb_heads=2,
         nb_blocks=2,
         dropout=0.1,
-        causal=True,
     )
 
     model.eval()
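
The commit replaces the boolean causal argument of QKVAttention and MyGPT with a
compute_attzero callable that receives broadcastable query and key time indices and
returns True wherever the attention logit must be masked. Below is a minimal sketch,
assuming only standard PyTorch: it checks that compute_causal_attzero reproduces the
mask previously hard-coded with torch.arange comparisons, and adds a hypothetical
windowed variant (compute_local_attzero and its window argument are illustrative,
not taken from the repository).

import torch


def compute_causal_attzero(t_q, t_k):
    # Same as the function added in main.py: a query at time t_q must not
    # attend to keys strictly in its future.
    return t_q < t_k


def compute_local_attzero(t_q, t_k, window=4):
    # Hypothetical alternative mask: causal attention restricted to the
    # last `window` positions (name and window size are illustrative).
    return (t_q < t_k) | (t_q - t_k >= window)


T = 6
t = torch.arange(T)

# How QKVAttention.forward builds the cached mask after this change.
causal_mask = compute_causal_attzero(t[:, None], t[None, :])[None, None, :, :]

# Equivalent to the construction removed from mygpt.py.
old_mask = (
    torch.arange(T)[None, None, :, None] < torch.arange(T)[None, None, None, :]
)
assert torch.equal(causal_mask, old_mask)

# Any callable with the same signature can be passed to MyGPT via
# compute_attzero=..., e.g. the windowed variant above.
local_mask = compute_local_attzero(t[:, None], t[None, :])[None, None, :, :]
print(local_mask[0, 0].int())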