Update
author      François Fleuret <francois@fleuret.org>
            Wed, 22 Mar 2023 22:10:53 +0000 (23:10 +0100)
committer   François Fleuret <francois@fleuret.org>
            Wed, 22 Mar 2023 22:10:53 +0000 (23:10 +0100)
beaver.py
mygpt.py

index 49cb1f6..8fe9a9b 100755
--- a/beaver.py
+++ b/beaver.py
@@ -131,13 +131,15 @@ for n in vars(args):
 ######################################################################
 
 
-def random_order(result, fixed_len):
+def generation_order(x, fixed_len):
     if args.random_regression_order:
-        order = torch.rand(result.size(), device=result.device)
+        order = torch.rand(x.size(), device=x.device)
         order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device)
         return order.sort(1).indices
     else:
-        return torch.arange(result.size(1)).unsqueeze(0).expand(result.size(0), -1)
+        return (
+            torch.arange(x.size(1), device=x.device).unsqueeze(0).expand(x.size(0), -1)
+        )
 
 
 def shuffle(x, order, reorder=False):
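
As a quick reference (not part of the patch), a standalone sketch of what the renamed generation_order produces when args.random_regression_order is set: the linspace values in [-2, -1] sort ahead of the uniform samples in [0, 1), so the first fixed_len positions keep their natural order while the remaining positions come out randomly permuted.

    import torch

    # Hypothetical standalone illustration of the sort-based ordering.
    x = torch.zeros(2, 8)
    fixed_len = 3
    order = torch.rand(x.size(), device=x.device)
    order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device)
    print(order.sort(1).indices[0])  # e.g. tensor([0, 1, 2, 6, 4, 7, 3, 5])
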
@@ -184,7 +186,7 @@ def compute_perplexity(model, split="train"):
 
         for input in task.batches(split=split):
             input = input.to(device)
-            order = random_order(input, task.height * task.width)
+            order = generation_order(input, task.height * task.width)
             input = shuffle(input, order)
             output = model(mygpt.BracketedSequence(input), order=order).x
             loss = F.cross_entropy(output.transpose(1, 2), input)
@@ -250,7 +252,7 @@ def oneshot(gpt, task):
 
         acc_train_loss, nb_train_samples = 0, 0
         for mazes, policies in task.policy_batches(split="train"):
-            order = random_order(mazes, task.height * task.width)
+            order = generation_order(mazes, task.height * task.width)
             x = shuffle(mazes, order)
             x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
             output_gpt = shuffle(x, order, reorder=True)
@@ -266,7 +268,7 @@ def oneshot(gpt, task):
 
         acc_test_loss, nb_test_samples = 0, 0
         for mazes, policies in task.policy_batches(split="test"):
-            order = random_order(mazes, task.height * task.width)
+            order = generation_order(mazes, task.height * task.width)
             x = shuffle(mazes, order)
             x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
             output_gpt = shuffle(x, order, reorder=True)
@@ -282,7 +284,7 @@ def oneshot(gpt, task):
         # -------------------
         mazes = task.test_input[:32, : task.height * task.width]
         policies = task.test_policies[:32]
-        order = random_order(mazes, task.height * task.width)
+        order = generation_order(mazes, task.height * task.width)
         x = shuffle(mazes, order)
         x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
         output_gpt = shuffle(x, order, reorder=True)
@@ -424,7 +426,7 @@ class TaskMaze(Task):
             ar_mask = result.new_zeros(result.size())
             ar_mask[:, self.height * self.width :] = 1
             result *= 1 - ar_mask
-            order = random_order(result, self.height * self.width)
+            order = generation_order(result, self.height * self.width)
             masked_inplace_autoregression(
                 model, self.batch_size, result, ar_mask, order=order
             )
@@ -606,7 +608,7 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
 
     for input in task.batches(split="train"):
         input = input.to(device)
-        order = random_order(input, task.height * task.width)
+        order = generation_order(input, task.height * task.width)
         input = shuffle(input, order)
         output = model(mygpt.BracketedSequence(input), order=order).x
         loss = F.cross_entropy(output.transpose(1, 2), input)
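
A side note on the unchanged loss line above (illustrative sizes, not from the repository): the model output .x is (batch, length, vocab_size) while F.cross_entropy expects the class dimension second, which is what the transpose(1, 2) provides.

    import torch
    import torch.nn.functional as F

    # Shape sketch with made-up sizes: logits (N, L, C) vs targets (N, L).
    output = torch.randn(4, 10, 32)         # (N, L, C)
    target = torch.randint(0, 32, (4, 10))  # (N, L)
    loss = F.cross_entropy(output.transpose(1, 2), target)
    print(loss.shape)  # torch.Size([])
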
index 311ff6b..75adbf6 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -106,20 +106,16 @@ class AddPositionalEncoding(nn.Module):
             )
 
             order_output = order + 1
-            order_input = torch.cat(
-                (order.new_zeros(order.size(0), 1), order[:, :-1] + 1), 1
-            )
+            order_input = F.pad(order + 1, (1, -1))
 
-            self.pe = torch.cat(
-                (
-                    pe.gather(1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))),
-                    pe.gather(
-                        1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
-                    ),
-                ),
-                2,
+            pe_input = pe.gather(
+                1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))
+            )
+            pe_output = pe.gather(
+                1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
             )
 
+            self.pe = torch.cat((pe_input, pe_output), 2)
             self.cache_y = bs.x.new(bs.x.size())
 
         self.cache_y[:, bs.first : bs.first + bs.nb] = (
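
The F.pad rewrite above is behaviour-preserving; a standalone check (shapes assumed for illustration) that F.pad(order + 1, (1, -1)) reproduces the previous torch.cat construction, followed by the gather that picks one positional-encoding row per position.

    import torch
    import torch.nn.functional as F

    # Check that a left pad of one zero plus a right truncation of one
    # element matches the old explicit concatenation.
    order = torch.randint(0, 10, (2, 10))
    old = torch.cat((order.new_zeros(order.size(0), 1), order[:, :-1] + 1), 1)
    new = F.pad(order + 1, (1, -1))
    assert torch.equal(old, new)

    # The gather then selects, for each sequence position, the row of the
    # positional-encoding table indexed by that position's rank; pe is
    # given an assumed shape (N, T + 1, D) so index T stays in range.
    pe = torch.randn(2, 11, 4)
    pe_input = pe.gather(1, new.unsqueeze(-1).expand(-1, -1, pe.size(-1)))
    print(pe_input.shape)  # torch.Size([2, 10, 4])
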