From: François Fleuret Date: Wed, 22 Mar 2023 15:26:33 +0000 (+0100) Subject: Update X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=9e731cac6db44cd6bf2f58e1a867e287e330f6dd;p=beaver.git Update --- diff --git a/beaver.py b/beaver.py index 6ec0fb2..5a15aee 100755 --- a/beaver.py +++ b/beaver.py @@ -245,7 +245,7 @@ def oneshot(gpt, task): acc_train_loss, nb_train_samples = 0, 0 for mazes, policies in task.policy_batches(split="train"): - order = random_order(input, task.height * task.width) + order = random_order(mazes, task.height * task.width) x = shuffle(mazes, order) x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x output_gpt = shuffle(x, order, reorder=True) @@ -261,7 +261,7 @@ def oneshot(gpt, task): acc_test_loss, nb_test_samples = 0, 0 for mazes, policies in task.policy_batches(split="test"): - order = random_order(input, task.height * task.width) + order = random_order(mazes, task.height * task.width) x = shuffle(mazes, order) x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x output_gpt = shuffle(x, order, reorder=True) @@ -277,7 +277,7 @@ def oneshot(gpt, task): # ------------------- mazes = task.test_input[:32, : task.height * task.width] policies = task.test_policies[:32] - order = random_order(input, task.height * task.width) + order = random_order(mazes, task.height * task.width) x = shuffle(mazes, order) x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x output_gpt = shuffle(x, order, reorder=True)