batch_size,
input,
ar_mask,
+ summed_logits,
temperature,
deterministic_synthesis,
forbidden_tokens=None,
t = model.training
model.eval()
- sum_logits = 0
-
for input, ar_mask in batches:
- sum_logits += model.masked_inplace_autoregression(
+ model.masked_inplace_autoregression(
input=input,
ar_mask=ar_mask,
+ summed_logits=summed_logits,
temperature=temperature,
deterministic_synthesis=deterministic_synthesis,
forbidden_tokens=forbidden_tokens,
model.train(t)
- return sum_logits
-
######################################################################
self.nb_batch_samples_world = input.size(0)
self.nb_batch_samples_quizzes = 0
+ # Shuffle
+ input = input[torch.randperm(input.size(0))]
+
if desc is None:
desc = f"epoch-{split}"
for batch in tqdm.tqdm(
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
quizzes = torch.empty(
nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
)
+
ar_mask = torch.full(quizzes.size(), 1, device=self.device)
+ summed_logits = torch.empty(nb, device=self.device)
- sum_logits = masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=quizzes,
- ar_mask=ar_mask,
- temperature=1.0,
- deterministic_synthesis=False,
- progress_bar_desc="creating quizzes",
- device=self.device,
- )
+ temperature = 1
+ d_temperature = 1
- average_logits = sum_logits / quizzes.numel()
+ while True:
+ summed_logits[...] = 0
- if desired_average_logits is not None:
- temperature = average_logits / desired_average_logits
masked_inplace_autoregression(
model=model,
batch_size=self.batch_size,
input=quizzes,
ar_mask=ar_mask,
+ summed_logits=summed_logits,
temperature=temperature,
deterministic_synthesis=False,
progress_bar_desc="creating quizzes",
device=self.device,
)
+ average_logits = summed_logits.mean()
+
+ logger(f"{average_logits=} {desired_average_logits=}")
+
+ if desired_average_logits is None:
+ break
+
+ # Oh man that's ugly
+ if average_logits < desired_average_logits * 1.1:
+ if d_temperature > 0:
+ d_temperature *= -0.5
+ temperature += d_temperature
+ elif average_logits > desired_average_logits:
+ if d_temperature < 0:
+ d_temperature *= -0.5
+ temperature += d_temperature
+ else:
+ break
+
+ logger(f"changing temperature to {temperature}")
+
###############################################################
# Create the reverse quizzes
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=True,
progress_bar_desc="solving quizzes",
batch_size=self.batch_size,
input=reverse_result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=True,
progress_bar_desc="solving reversed quizzes",
nb_correct = torch.cat(nb_correct, dim=0)
- filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
- with open(filename, "w") as f:
- for k in nb_correct:
- f.write(f"{k}\n")
+ # filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
+ # with open(filename, "w") as f:
+ # for k in nb_correct:
+ # f.write(f"{k}\n")
- return quizzes, nb_correct.sum(dim=0), average_logits
+ return quizzes, nb_correct.sum(dim=0), summed_logits.mean()