(("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
(("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
(("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
+ # (("B", "f_B", "A", "f_A"), (0, 0, 1, 1), model_transformer_cold),
+ # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
]
######################################################################
def save_additional_results(models, science_w_quizzes):
for model in models:
+ recorder = []
+
c_quizzes = quiz_machine.generate_c_quizzes(
- 128, model_for_generation=model, procedure=c_quizzes_procedure
+ 32,
+ model_for_generation=model,
+ procedure=c_quizzes_procedure,
+ recorder=recorder,
)
+ c_quizzes = torch.cat([c[:, None, :] for c, _, in recorder], dim=1)
+ predicted_parts = torch.cat([t[:, None, :] for _, t in recorder], dim=1)
+ nrow = c_quizzes.size(1)
+ c_quizzes = c_quizzes.reshape(-1, c_quizzes.size(-1))
+ predicted_parts = predicted_parts.reshape(-1, predicted_parts.size(-1))
+
+ filename = f"non_validated_{n_epoch:04d}_{model.id:02d}.png"
quiz_machine.problem.save_quizzes_as_image(
args.result_dir,
- f"non_validated_{n_epoch:04d}_{model.id:02d}.png",
- c_quizzes,
+ filename,
+ quizzes=c_quizzes,
+ predicted_parts=predicted_parts,
+ nrow=nrow,
)
+ log_string(f"wrote {filename}")
######################################################################
######################################################################
- def generate_c_quizzes(self, nb, model_for_generation, procedure, to_recycle=None):
+ def generate_c_quizzes(
+ self, nb, model_for_generation, procedure, to_recycle=None, recorder=None
+ ):
seq_logproba = torch.zeros(nb, device=self.device)
c_quizzes = None
model_for_generation.reset_transformations()
+ if recorder is not None:
+ x = c_quizzes.clone()
+ t = torch.tensor(m, device=x.device)[None, :].expand(x.size(0), -1)
+ recorder.append(
+ self.problem.reconfigure([x, t], ("A", "f_A", "B", "f_B"))
+ )
+
if to_recycle is not None and to_recycle.size(0) > 0:
to_recycle = self.problem.reconfigure(to_recycle, s)
c_quizzes[: to_recycle.size(0)] = to_recycle