batch_size,
input,
ar_mask,
+ summed_logits,
temperature,
deterministic_synthesis,
forbidden_tokens=None,
total=(input.size(0) + batch_size - 1) // batch_size,
)
- sum_logits = 0
-
with torch.autograd.no_grad():
t = model.training
model.eval()
for input, ar_mask in batches:
- sum_logits += model.masked_inplace_autoregression(
+ model.masked_inplace_autoregression(
input=input,
ar_mask=ar_mask,
+ summed_logits=summed_logits,
temperature=temperature,
deterministic_synthesis=deterministic_synthesis,
forbidden_tokens=forbidden_tokens,
model.train(t)
- return sum_logits
-
######################################################################
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
)
ar_mask = torch.full(quizzes.size(), 1, device=self.device)
+ summed_logits = torch.empty(nb, device=self.device)
temperature = 1
d_temperature = 1
while True:
- sum_logits = masked_inplace_autoregression(
+ summed_logits[...] = 0
+
+ masked_inplace_autoregression(
model=model,
batch_size=self.batch_size,
input=quizzes,
ar_mask=ar_mask,
+ summed_logits=summed_logits,
temperature=temperature,
deterministic_synthesis=False,
progress_bar_desc="creating quizzes",
device=self.device,
)
- average_logits = sum_logits / quizzes.size(0)
+ average_logits = summed_logits.mean()
logger(f"{average_logits=} {desired_average_logits=}")
break
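# Oh man that's ugly: nudge the temperature by d_temperature, reversing and halving the step when it points the wrong way, until average_logits lands between the two bounds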
- if average_logits > desired_average_logits:
- if d_temperature < 0:
+ if average_logits < desired_average_logits * 1.1:
+ if d_temperature > 0:
d_temperature *= -0.5
temperature += d_temperature
- else:
- if d_temperature > 0:
+ elif average_logits > desired_average_logits:
+ if d_temperature < 0:
d_temperature *= -0.5
temperature += d_temperature
- logger(f"chaging temperature to {temperature}")
+ else:
+ break
+
+ logger(f"changing temperature to {temperature}")
###############################################################
# Create the reverse quizzes
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=True,
progress_bar_desc="solving quizzes",
batch_size=self.batch_size,
input=reverse_result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=True,
progress_bar_desc="solving reversed quizzes",
# for k in nb_correct:
# f.write(f"{k}\n")
- return quizzes, nb_correct.sum(dim=0), sum_logits
+ return quizzes, nb_correct.sum(dim=0), summed_logits.mean()