self.save_image(
result[:96],
result_dir,
- f"world_result_{n_epoch:04d}_{model.id:02d}.png",
+ f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
logger,
)
device=self.device,
)
- nb_correct += (new_quizzes == result).long().min(dim=-1).values
+ l = self.height * self.width
+ direction = new_quizzes[:, l : l + 1]
+ direction = world.token_forward * (
+ direction == world.token_backward
+ ) + world.token_backward * (direction == world.token_forward)
+ inverted_quizzes = torch.cat(
+ [new_quizzes[:, l + 1 :], direction, new_quizzes[:, :l]], dim=1
+ )
+
+ inverted_result = inverted_quizzes.clone()
+
+ masked_inplace_autoregression(
+ m,
+ self.batch_size,
+ inverted_result,
+ ar_mask,
+ deterministic_synthesis=True,
+ progress_bar_desc="solving reverse quizzes",
+ device=self.device,
+ )
+
+ nb_correct += (new_quizzes == result).long().min(dim=-1).values * (
+ inverted_quizzes == inverted_result
+ ).long().min(dim=-1).values
return new_quizzes, nb_correct