projects
/
beaver.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
2cd3f15
)
Update
author
François Fleuret
<francois@fleuret.org>
Wed, 22 Mar 2023 15:26:33 +0000
(16:26 +0100)
committer
François Fleuret
<francois@fleuret.org>
Wed, 22 Mar 2023 15:26:33 +0000
(16:26 +0100)
beaver.py
patch
|
blob
|
history
diff --git
a/beaver.py
b/beaver.py
index
6ec0fb2
..
5a15aee
100755
(executable)
--- a/
beaver.py
+++ b/
beaver.py
@@
-245,7
+245,7
@@
def oneshot(gpt, task):
acc_train_loss, nb_train_samples = 0, 0
for mazes, policies in task.policy_batches(split="train"):
acc_train_loss, nb_train_samples = 0, 0
for mazes, policies in task.policy_batches(split="train"):
- order = random_order(
input
, task.height * task.width)
+ order = random_order(
mazes
, task.height * task.width)
x = shuffle(mazes, order)
x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
output_gpt = shuffle(x, order, reorder=True)
x = shuffle(mazes, order)
x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
output_gpt = shuffle(x, order, reorder=True)
@@
-261,7
+261,7
@@
def oneshot(gpt, task):
acc_test_loss, nb_test_samples = 0, 0
for mazes, policies in task.policy_batches(split="test"):
acc_test_loss, nb_test_samples = 0, 0
for mazes, policies in task.policy_batches(split="test"):
- order = random_order(
input
, task.height * task.width)
+ order = random_order(
mazes
, task.height * task.width)
x = shuffle(mazes, order)
x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
output_gpt = shuffle(x, order, reorder=True)
x = shuffle(mazes, order)
x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
output_gpt = shuffle(x, order, reorder=True)
@@
-277,7
+277,7
@@
def oneshot(gpt, task):
# -------------------
mazes = task.test_input[:32, : task.height * task.width]
policies = task.test_policies[:32]
# -------------------
mazes = task.test_input[:32, : task.height * task.width]
policies = task.test_policies[:32]
- order = random_order(
input
, task.height * task.width)
+ order = random_order(
mazes
, task.height * task.width)
x = shuffle(mazes, order)
x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
output_gpt = shuffle(x, order, reorder=True)
x = shuffle(mazes, order)
x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
output_gpt = shuffle(x, order, reorder=True)