return (output - targets).abs().sum() / masks.sum()
-def oneshot(gpt, learning_rate_scheduler, task):
+def oneshot(model, learning_rate_scheduler, task):
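+ # One-shot readout: blank the path tokens, run a single forward pass, and
+ # save a heat map of the per-cell path probability over the input mazes.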
+ t = model.training
+ model.eval()
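+ # take the first 32 test mazes and zero out everything after the maze
+ # descriptor (i.e. the path tokens)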
+ mazes = task.test_input[:32].clone()
+ mazes[:, task.height * task.width :] = 0
+ output = eval_mygpt(model, mazes, prompt_len=task.height * task.width)
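+ # logits -> per-position distribution over the vocabulary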
+ output = F.softmax(output, dim=2)
+ print(f"{output.size()=}")
+ proba_path = output[:, task.height * task.width :, 4].reshape(
+ -1, task.height, task.width
+ )
+ mazes = mazes[:, : task.height * task.width].reshape(-1, task.height, task.width)
+ # targets = targets.reshape(-1, task.height, task.width)
+ filename = f"oneshot.png"
+ maze.save_image(
+ os.path.join(args.result_dir, filename),
+ mazes=mazes,
+ score_paths=proba_path,
+ # score_truth=targets,
+ )
+ log_string(f"wrote {filename}")
+
+
+def oneshot_old(gpt, learning_rate_scheduler, task):
t = gpt.training
gpt.eval()
q = torch.arange(d)[:, None]
k = torch.arange(d)[None, :]
s = args.maze_height * args.maze_width
- # return torch.logical_and(q < k, torch.logical_or(q >= s, k >= s))
- return q < k
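+ # True entries are masked: q < k alone is the usual causal mask; the extra
+ # condition lifts it inside the prompt (first s tokens), so the prompt is
+ # processed bidirectionally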
+ return torch.logical_and(q < k, torch.logical_or(q >= s, k >= s))
+ # return q < k
+
+
+def noncausal_prompt_oneshot_amm_generator(d):
+ q = torch.arange(d)[:, None]
+ k = torch.arange(d)[None, :]
+ s = args.maze_height * args.maze_width
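+ # mask every key in the generated part (k >= s): each position attends to
+ # the prompt only, so the whole path distribution comes out in one pass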
+ return k >= s
+ # return q < k
-amm_generator = None
-if args.noncausal_prompt:
+if args.oneshot:
+ amm_generator = noncausal_prompt_oneshot_amm_generator
+elif args.noncausal_prompt:
amm_generator = noncausal_prompt_amm_generator
+else:
+ amm_generator = None
model = mygpt.MyGPT(
vocabulary_size=vocabulary_size,
######################################################################
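+# in one-shot mode, just run the evaluation and exit before training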
+if args.oneshot:
+ oneshot(model, learning_rate_scheduler, task)
+ exit(0)
+
+######################################################################
+
token_count = 0
for input in task.batches(split="train"):
token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
log_string(f"saved checkpoint {checkpoint_name}")
######################################################################
-
-if args.oneshot:
- oneshot(model, learning_rate_scheduler, task)
-
-######################################################################