projects
/
beaver.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
c1fd5d2
)
Update
author
François Fleuret
<francois@fleuret.org>
Tue, 21 Mar 2023 05:33:47 +0000
(06:33 +0100)
committer
François Fleuret
<francois@fleuret.org>
Tue, 21 Mar 2023 05:33:47 +0000
(06:33 +0100)
beaver.py
patch
|
blob
|
history
diff --git
a/beaver.py
b/beaver.py
index
5916215
..
f5bd924
100755
(executable)
--- a/beaver.py
+++ b/beaver.py
@@ -79,6 +79,9 @@
parser.add_argument("--maze_width", type=int, default=21)
parser.add_argument("--maze_nb_walls", type=int, default=15)
+##############################
+# one-shot prediction
+
parser.add_argument("--oneshot", action="store_true", default=False)
parser.add_argument("--oneshot_input", type=str, default="head")
@@ -224,11 +227,6 @@
def oneshot(gpt, task):
    acc_train_loss, nb_train_samples = 0, 0
    for mazes, policies in task.policy_batches(split="train"):
- ####
- # print(f'{mazes.size()=} {policies.size()=}')
- # s = maze.stationary_densities(
- # exit(0)
- ####
        output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
        output = model(output_gpt)
@@ -535,12 +533,6 @@
log_string(f"learning_rate_schedule {learning_rate_schedule}")
##############################
-if args.oneshot:
- oneshot(model, task)
- exit(0)
-
-##############################
-
if nb_epochs_finished >= args.nb_epochs:
    n_epoch = nb_epochs_finished
    train_perplexity = compute_perplexity(model, split="train")
@@ -608,3 +600,8 @@
for n_epoch in range(nb_epochs_finished, args.nb_epochs):
    log_string(f"saved checkpoint {checkpoint_name}")

######################################################################
+
+if args.oneshot:
+ oneshot(model, task)
+
+######################################################################