+ ar_mask = torch.full(quizzes.size(), 1, device=self.device)
+ summed_logits = torch.empty(nb, device=self.device)
+
+ temperature = 1
+ d_temperature = 1
+
+ while True:
+ summed_logits[...] = 0
+
+ masked_inplace_autoregression(
+ model=model,
+ batch_size=self.batch_size,
+ input=quizzes,
+ ar_mask=ar_mask,
+ summed_logits=summed_logits,
+ temperature=temperature,
+ deterministic_synthesis=False,
+ progress_bar_desc="creating quizzes",
+ device=self.device,
+ )
+
+ average_logits = summed_logits.mean()
+
+ logger(f"{average_logits=} {desired_average_logits=}")
+
+ if desired_average_logits is None:
+ break
+
+ # Oh man that's ugly
+ if average_logits < desired_average_logits * 1.1:
+ if d_temperature > 0:
+ d_temperature *= -0.5
+ temperature += d_temperature
+ elif average_logits > desired_average_logits:
+ if d_temperature < 0:
+ d_temperature *= -0.5
+ temperature += d_temperature
+ else:
+ break
+
+ logger(f"changing temperature to {temperature}")
+
+ ###############################################################
+ # Create the reverse quizzes
+
+ l = self.height * self.width
+ direction = quizzes[:, l : l + 1]
+ direction = world.token_forward * (
+ direction == world.token_backward
+ ) + world.token_backward * (direction == world.token_forward)
+ reverse_quizzes = torch.cat(
+ [quizzes[:, l + 1 :], direction, quizzes[:, :l]], dim=1