From 9e731cac6db44cd6bf2f58e1a867e287e330f6dd Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 22 Mar 2023 16:26:33 +0100 Subject: [PATCH] Update --- beaver.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/beaver.py b/beaver.py index 6ec0fb2..5a15aee 100755 --- a/beaver.py +++ b/beaver.py @@ -245,7 +245,7 @@ def oneshot(gpt, task): acc_train_loss, nb_train_samples = 0, 0 for mazes, policies in task.policy_batches(split="train"): - order = random_order(input, task.height * task.width) + order = random_order(mazes, task.height * task.width) x = shuffle(mazes, order) x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x output_gpt = shuffle(x, order, reorder=True) @@ -261,7 +261,7 @@ def oneshot(gpt, task): acc_test_loss, nb_test_samples = 0, 0 for mazes, policies in task.policy_batches(split="test"): - order = random_order(input, task.height * task.width) + order = random_order(mazes, task.height * task.width) x = shuffle(mazes, order) x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x output_gpt = shuffle(x, order, reorder=True) @@ -277,7 +277,7 @@ def oneshot(gpt, task): # ------------------- mazes = task.test_input[:32, : task.height * task.width] policies = task.test_policies[:32] - order = random_order(input, task.height * task.width) + order = random_order(mazes, task.height * task.width) x = shuffle(mazes, order) x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x output_gpt = shuffle(x, order, reorder=True) -- 2.20.1