Update
authorFrançois Fleuret <francois@fleuret.org>
Thu, 23 Mar 2023 18:11:47 +0000 (19:11 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 23 Mar 2023 18:11:47 +0000 (19:11 +0100)
beaver.py

index 86008f6..6a6343d 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -131,10 +131,10 @@ for n in vars(args):
 ######################################################################
 
 
-def generation_order(x, fixed_len):
+def generation_order(x, fixed_len=0):
     if args.random_regression_order:
         order = torch.rand(x.size(), device=x.device)
-        order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=x.device)
+        order[:, :fixed_len] = torch.arange(-fixed_len, 0, device=x.device)
         order = order.sort(1).indices
     else:
         order = (
@@ -147,8 +147,7 @@ def reorder(x, order, reverse=False):  # x is NxTxD1x...xDk, order is NxT'
     u = x.reshape(x.size()[:2] + (-1,))
     order = order.unsqueeze(-1).expand(-1, -1, u.size(-1))
     if reverse:
-        v = u.new(u.size())
-        v.scatter_(1, order, u)
+        v = u.new(u.size()).scatter_(1, order, u)
     else:
         v = u.gather(1, order)
     v = v.reshape(v.size()[:2] + x.size()[2:])
@@ -160,6 +159,12 @@ def shuffle(x, fixed_len):
     return reorder(x, order), order
 
 
+def eval_mygpt(model, input, mode="standard", fixed_len=0):
+    x, order = shuffle(input, fixed_len)
+    x = model(mygpt.BracketedSequence(x), mode=mode, order=order).x
+    return reorder(x, order, reverse=True)
+
+
 ######################################################################
 
 # ar_mask is a Boolean matrix of same shape as input, with 1s on the
@@ -197,9 +202,7 @@ def compute_perplexity(model, task, fixed_len, split="train"):
 
         for input in task.batches(split=split):
             input = input.to(device)
-            x, order = shuffle(input, fixed_len)
-            x = model(mygpt.BracketedSequence(x), order=order).x
-            output = reorder(x, order, reverse=True)
+            output = eval_mygpt(model, input, fixed_len=fixed_len)
             loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_loss += loss.item() * input.size(0)
             nb_samples += input.size(0)
@@ -263,9 +266,9 @@ def oneshot(gpt, task):
 
         acc_train_loss, nb_train_samples = 0, 0
         for mazes, policies in task.policy_batches(split="train"):
-            x, order = shuffle(mazes, task.height * task.width)
-            x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
-            output_gpt = reorder(x, order, reverse=True)
+            output_gpt = eval_mygpt(
+                gpt, mazes, mode=args.oneshot_input, fixed_len=task.height * task.width
+            )
             output = model(output_gpt)
 
             loss = compute_loss(mazes, output, policies, task.height, task.width)
@@ -278,9 +281,9 @@ def oneshot(gpt, task):
 
         acc_test_loss, nb_test_samples = 0, 0
         for mazes, policies in task.policy_batches(split="test"):
-            x, order = shuffle(mazes, task.height * task.width)
-            x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
-            output_gpt = reorder(x, order, reverse=True)
+            output_gpt = eval_mygpt(
+                gpt, mazes, mode=args.oneshot_input, fixed_len=task.height * task.width
+            )
             output = model(output_gpt)
             loss = compute_loss(mazes, output, policies, task.height, task.width)
             acc_test_loss += loss.item() * mazes.size(0)
@@ -293,9 +296,9 @@ def oneshot(gpt, task):
         # -------------------
         mazes = task.test_input[:32, : task.height * task.width]
         policies = task.test_policies[:32]
-        x, order = shuffle(mazes, task.height * task.width)
-        x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
-        output_gpt = reorder(x, order, reverse=True)
+        output_gpt = eval_mygpt(
+            gpt, mazes, mode=args.oneshot_input, fixed_len=task.height * task.width
+        )
         output = model(output_gpt)
         if args.oneshot_output == "policy":
             targets = policies.permute(0, 2, 1)
@@ -628,9 +631,9 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
 
     for input in task.batches(split="train"):
         input = input.to(device)
-        x, order = shuffle(input, task.height * task.width)
-        x = model(mygpt.BracketedSequence(x), order=order).x
-        output = reorder(x, order, reverse=True)
+        output = eval_mygpt(
+            model, input, mode=args.oneshot_input, fixed_len=task.height * task.width
+        )
         loss = F.cross_entropy(output.transpose(1, 2), input)
         acc_train_loss += loss.item() * input.size(0)
         nb_train_samples += input.size(0)