Update
authorFrançois Fleuret <francois@fleuret.org>
Mon, 27 Mar 2023 13:38:16 +0000 (15:38 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 27 Mar 2023 13:38:16 +0000 (15:38 +0200)
beaver.py

index 5407859..074e137 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -133,18 +133,6 @@ for n in vars(args):
 ######################################################################
 
 
-def generation_order(x, prompt_len=0):
-    if args.random_regression_order:
-        order = torch.rand(x.size(), device=x.device)
-        order[:, :prompt_len] = torch.arange(-prompt_len, 0, device=x.device)
-        order = order.sort(1).indices
-    else:
-        order = (
-            torch.arange(x.size(1), device=x.device).unsqueeze(0).expand(x.size(0), -1)
-        )
-    return order
-
-
 def reorder(x, order, reverse=False):  # x is NxTxD1x...xDk, order is NxT'
     u = x.reshape(x.size()[:2] + (-1,))
     order = order.unsqueeze(-1).expand(-1, -1, u.size(-1))
@@ -157,7 +145,14 @@ def reorder(x, order, reverse=False):  # x is NxTxD1x...xDk, order is NxT'
 
 
 def shuffle(x, prompt_len):
-    order = generation_order(x, prompt_len)
+    if args.random_regression_order:
+        order = torch.rand(x.size(), device=x.device)
+        order[:, :prompt_len] = torch.arange(-prompt_len, 0, device=x.device)
+        order = order.sort(1).indices
+    else:
+        order = (
+            torch.arange(x.size(1), device=x.device).unsqueeze(0).expand(x.size(0), -1)
+        )
     return reorder(x, order), order