Update
[beaver.git] / beaver.py
index 6ec0fb2..49cb1f6 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -64,6 +64,8 @@ parser.add_argument("--dropout", type=float, default=0.1)
 
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
+parser.add_argument("--random_regression_order", action="store_true", default=False)
+
 parser.add_argument("--no_checkpoint", action="store_true", default=False)
 
 parser.add_argument("--overwrite_results", action="store_true", default=False)
@@ -130,9 +132,12 @@ for n in vars(args):
 
 
 def random_order(result, fixed_len):
-    order = torch.rand(result.size(), device=result.device)
-    order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device)
-    return order.sort(1).indices
+    if args.random_regression_order:
+        order = torch.rand(result.size(), device=result.device)
+        order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device)
+        return order.sort(1).indices
+    else:
+        return torch.arange(result.size(1)).unsqueeze(0).expand(result.size(0), -1)
 
 
 def shuffle(x, order, reorder=False):
@@ -245,7 +250,7 @@ def oneshot(gpt, task):
 
         acc_train_loss, nb_train_samples = 0, 0
         for mazes, policies in task.policy_batches(split="train"):
-            order = random_order(input, task.height * task.width)
+            order = random_order(mazes, task.height * task.width)
             x = shuffle(mazes, order)
             x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
             output_gpt = shuffle(x, order, reorder=True)
@@ -261,7 +266,7 @@ def oneshot(gpt, task):
 
         acc_test_loss, nb_test_samples = 0, 0
         for mazes, policies in task.policy_batches(split="test"):
-            order = random_order(input, task.height * task.width)
+            order = random_order(mazes, task.height * task.width)
             x = shuffle(mazes, order)
             x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
             output_gpt = shuffle(x, order, reorder=True)
@@ -277,7 +282,7 @@ def oneshot(gpt, task):
         # -------------------
         mazes = task.test_input[:32, : task.height * task.width]
         policies = task.test_policies[:32]
-        order = random_order(input, task.height * task.width)
+        order = random_order(mazes, task.height * task.width)
         x = shuffle(mazes, order)
         x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
         output_gpt = shuffle(x, order, reorder=True)
@@ -579,8 +584,6 @@ if nb_epochs_finished >= args.nb_epochs:
 
     task.produce_results(n_epoch, model)
 
-    exit(0)
-
 ##############################
 
 for n_epoch in range(nb_epochs_finished, args.nb_epochs):