Update
authorFrançois Fleuret <francois@fleuret.org>
Fri, 24 Mar 2023 19:51:50 +0000 (20:51 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 24 Mar 2023 19:51:50 +0000 (20:51 +0100)
beaver.py

index 4f41832..5ee468e 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -524,7 +524,10 @@ amm_generator = None
 if args.noncausal_prompt:
     amm_generator = lambda d: torch.logical_and(
         torch.arange(d)[None, None, :, None] < torch.arange(d)[None, None, None, :],
-        torch.arange(d)[None, None, :, None] >= d // 2,
+        torch.logical_or(
+            torch.arange(d)[None, None, :, None] >= d // 2,
+            torch.arange(d)[None, None, None, :] >= d // 2,
+        ),
     )
 
 model = mygpt.MyGPT(