Update
author      François Fleuret <francois@fleuret.org>
            Fri, 24 Mar 2023 17:34:30 +0000 (18:34 +0100)
committer   François Fleuret <francois@fleuret.org>
            Fri, 24 Mar 2023 17:34:30 +0000 (18:34 +0100)
beaver.py
mygpt.py

index 6a6343d..4f41832 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -66,6 +66,8 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa
 
 parser.add_argument("--random_regression_order", action="store_true", default=False)
 
+parser.add_argument("--noncausal_prompt", action="store_true", default=False)
+
 parser.add_argument("--no_checkpoint", action="store_true", default=False)
 
 parser.add_argument("--overwrite_results", action="store_true", default=False)
@@ -517,6 +519,14 @@ log_string(f"vocabulary_size {vocabulary_size}")
 
 ##############################
 
+amm_generator = None
+
+if args.noncausal_prompt:
+    amm_generator = lambda d: torch.logical_and(
+        torch.arange(d)[None, None, :, None] < torch.arange(d)[None, None, None, :],
+        torch.arange(d)[None, None, :, None] >= d // 2,
+    )
+
 model = mygpt.MyGPT(
     vocabulary_size=vocabulary_size,
     dim_model=args.dim_model,
@@ -526,6 +536,7 @@ model = mygpt.MyGPT(
     nb_blocks=args.nb_blocks,
     causal=True,
     dropout=args.dropout,
+    amm_generator=amm_generator,
 )
 
 model.to(device)
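
The new --noncausal_prompt option builds an attention-mask generator under which only queries in the second half of the sequence are restricted to past positions; queries in the first half (the prompt) attend to the whole sequence. A minimal sketch, not part of the commit, printing that mask next to the default causal one for a toy sequence length (True entries are the ones later filled with -inf before the softmax):

    import torch

    d = 8  # toy sequence length, purely illustrative

    # mask produced by the --noncausal_prompt generator
    noncausal_prompt_mask = torch.logical_and(
        torch.arange(d)[None, None, :, None] < torch.arange(d)[None, None, None, :],
        torch.arange(d)[None, None, :, None] >= d // 2,
    )

    # default strictly-causal mask
    causal_mask = (
        torch.arange(d)[None, None, :, None] < torch.arange(d)[None, None, None, :]
    )

    print(noncausal_prompt_mask[0, 0].int())  # first d // 2 rows are all zeros
    print(causal_mask[0, 0].int())            # strictly upper-triangular ones
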
index 75adbf6..7166788 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -132,13 +132,28 @@ class AddPositionalEncoding(nn.Module):
 
 class QKVAttention(nn.Module):
     def __init__(
-        self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+        self,
+        dim_in,
+        dim_qk,
+        dim_v,
+        nb_heads=1,
+        causal=False,
+        attention_dropout=0.0,
+        amm_generator=None,
     ):
         super().__init__()
 
         def randw(*d):
             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
 
+        if amm_generator is None:
+            self.amm_generator = (
+                lambda d: torch.arange(d)[None, None, :, None]
+                < torch.arange(d)[None, None, None, :]
+            )
+        else:
+            self.amm_generator = amm_generator
+
         self.causal = causal
         self.attention_dropout = attention_dropout
 
@@ -175,10 +190,7 @@ class QKVAttention(nn.Module):
 
         if self.causal:
             if bs_q.first == 0:
-                self.cache_attzero = (
-                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
-                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
-                )
+                self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)
             a = a.masked_fill(
                 self.cache_attzero[
                     :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
@@ -215,6 +227,7 @@ class MyGPT(nn.Module):
         causal=False,
         dropout=0.0,
         len_max=1e5,
+        amm_generator=None,
     ):
         super().__init__()
 
@@ -238,6 +251,7 @@ class MyGPT(nn.Module):
                         nb_heads=nb_heads,
                         causal=causal,
                         attention_dropout=dropout,
+                        amm_generator=amm_generator,
                     ),
                 ),
                 WithResidual(
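
QKVAttention, and MyGPT which forwards the argument to it, now accepts an optional amm_generator; when it is None the module falls back to the same strictly-causal mask as before, so the default behaviour is unchanged. A hedged usage sketch, with illustrative dimensions that are not taken from the commit:

    import torch
    from mygpt import QKVAttention

    def prefix_mask(d):
        # same mask as the --noncausal_prompt generator in beaver.py
        return torch.logical_and(
            torch.arange(d)[None, None, :, None] < torch.arange(d)[None, None, None, :],
            torch.arange(d)[None, None, :, None] >= d // 2,
        )

    attn = QKVAttention(
        dim_in=64,   # illustrative values
        dim_qk=16,
        dim_v=16,
        nb_heads=4,
        causal=True,  # the generator is only consulted when causal=True
        amm_generator=prefix_mask,
    )

    # forward() moves the generated mask to the right device and caches it
    # as cache_attzero before the masked_fill
    print(attn.amm_generator(6)[0, 0].int())
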