(("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot),
(("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold),
(("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
- (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold),
+ # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold),
]
######################################################################
output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x
dist = torch.distributions.categorical.Categorical(logits=output)
- input[:, 3 * L :] = dist.sample()
+ input[:, 3 * L + 1 :] = dist.sample()[:, 1:]
+
+ problem.save_quizzes_as_image(
+ args.result_dir,
+ f"thinker_prediction_{n_epoch:04d}.png",
+ quizzes=input,
+ # predicted_parts=predicted_parts,
+ # correct_parts=correct_parts,
+ )
######################################################################