Update
authorFrançois Fleuret <francois@fleuret.org>
Wed, 22 Mar 2023 15:26:33 +0000 (16:26 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 22 Mar 2023 15:26:33 +0000 (16:26 +0100)
beaver.py

index 6ec0fb2..5a15aee 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -245,7 +245,7 @@ def oneshot(gpt, task):
 
         acc_train_loss, nb_train_samples = 0, 0
         for mazes, policies in task.policy_batches(split="train"):
-            order = random_order(input, task.height * task.width)
+            order = random_order(mazes, task.height * task.width)
             x = shuffle(mazes, order)
             x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
             output_gpt = shuffle(x, order, reorder=True)
@@ -261,7 +261,7 @@ def oneshot(gpt, task):
 
         acc_test_loss, nb_test_samples = 0, 0
         for mazes, policies in task.policy_batches(split="test"):
-            order = random_order(input, task.height * task.width)
+            order = random_order(mazes, task.height * task.width)
             x = shuffle(mazes, order)
             x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
             output_gpt = shuffle(x, order, reorder=True)
@@ -277,7 +277,7 @@ def oneshot(gpt, task):
         # -------------------
         mazes = task.test_input[:32, : task.height * task.width]
         policies = task.test_policies[:32]
-        order = random_order(input, task.height * task.width)
+        order = random_order(mazes, task.height * task.width)
         x = shuffle(mazes, order)
         x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
         output_gpt = shuffle(x, order, reorder=True)