parser.add_argument("--maze_nb_walls", type=int, default=15)
+##############################
+# one-shot prediction
+
parser.add_argument("--oneshot", action="store_true", default=False)
parser.add_argument("--oneshot_input", type=str, default="head")
acc_train_loss, nb_train_samples = 0, 0
for mazes, policies in task.policy_batches(split="train"):
- ####
- # print(f'{mazes.size()=} {policies.size()=}')
- # s = maze.stationary_densities(
- # exit(0)
- ####
output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
output = model(output_gpt)
##############################
-if args.oneshot:
- oneshot(model, task)
- exit(0)
-
-##############################
-
if nb_epochs_finished >= args.nb_epochs:
n_epoch = nb_epochs_finished
train_perplexity = compute_perplexity(model, split="train")
log_string(f"saved checkpoint {checkpoint_name}")
######################################################################
+
+if args.oneshot:
+ oneshot(model, task)
+
+######################################################################