total=(input.size(0) + batch_size - 1) // batch_size,
)
+ sum_logits = 0
+
with torch.autograd.no_grad():
t = model.training
model.eval()
- sum_logits = 0
-
for input, ar_mask in batches:
sum_logits += model.masked_inplace_autoregression(
input=input,
model.train(t)
- return sum_logits
+ return sum_logits
######################################################################
quizzes = torch.empty(
nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
)
- ar_mask = torch.full(quizzes.size(), 1, device=self.device)
-
- sum_logits = masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=quizzes,
- ar_mask=ar_mask,
- temperature=1.0,
- deterministic_synthesis=False,
- progress_bar_desc="creating quizzes",
- device=self.device,
- )
-
- # Should not be necessary though, the autoregression is done
- # in eval mode
- sum_logits = sum_logits.detach()
- average_logits = sum_logits / quizzes.numel()
+ ar_mask = torch.full(quizzes.size(), 1, device=self.device)
- # It's a bit brutal to do it twice, we should probably have a
- # moving average and apply it right away
+ temperature = 1
+ d_temperature = 1
- if desired_average_logits is not None:
- temperature = average_logits / desired_average_logits
- masked_inplace_autoregression(
+ while True:
+ sum_logits = masked_inplace_autoregression(
model=model,
batch_size=self.batch_size,
input=quizzes,
device=self.device,
)
+ average_logits = sum_logits / quizzes.size(0)
+
+ logger(f"{average_logits=} {desired_average_logits=}")
+
+ if desired_average_logits is None:
+ break
+
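+ # Oh man that's ugly: crude temperature search -- d_temperature is halved and sign-flipped when it points the wrong way relative to the target, then added to temperature (NOTE(review): confirm an exit condition exists)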
+ if average_logits > desired_average_logits:
+ if d_temperature < 0:
+ d_temperature *= -0.5
+ temperature += d_temperature
+ else:
+ if d_temperature > 0:
+ d_temperature *= -0.5
+ temperature += d_temperature
+
+ logger(f"chaging temperature to {temperature}")
+
###############################################################
# Create the reverse quizzes
nb_correct = torch.cat(nb_correct, dim=0)
- filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
- with open(filename, "w") as f:
- for k in nb_correct:
- f.write(f"{k}\n")
+ # filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
+ # with open(filename, "w") as f:
+ # for k in nb_correct:
+ # f.write(f"{k}\n")
- return quizzes, nb_correct.sum(dim=0), average_logits
+ return quizzes, nb_correct.sum(dim=0), sum_logits