######################################################################
if args.check:
- args.nb_train_samples = 2500
- args.nb_test_samples = 100
+ args.nb_train_samples = 25000
+ args.nb_test_samples = 1000
if args.physical_batch_size is None:
args.physical_batch_size = args.batch_size
desired_average_logits=None,
):
kept = []
- nb_generated_tokens, sum_logits = 0, 0
+
+ sum_logits = 0
while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
nb_to_generate = 4 * (nb_for_train + nb_for_test)
- new_quizzes, nb_correct, average_logits = task.create_new_quizzes(
+
+ new_quizzes, nb_correct, _sum_logits = task.create_new_quizzes(
n_epoch=n_epoch,
result_dir=args.result_dir,
logger=log_string,
desired_average_logits=desired_average_logits,
)
- nb_generated_tokens += new_quizzes.numel()
- sum_logits += average_logits * new_quizzes.numel()
+ sum_logits += _sum_logits
to_keep = new_quizzes[nb_correct == len(other_models) - 1]
log_string(
log_string,
)
- return sum_logits / nb_generated_tokens
+ return sum_logits / new_quizzes.size(0)
######################################################################
if args.check:
accuracy_to_make_quizzes = 0.0
- nb_new_quizzes_for_train = 10
+ nb_new_quizzes_for_train = 100
nb_new_quizzes_for_test = 10
desired_average_logits = None
forced_biases=None,
):
sum_logits = 0
+
to_generate = (ar_mask.sum(0) > 0).nonzero()
+
if to_generate.min() > 0:
self(
BracketedSequence(input, 0, to_generate.min())
) # Needed to initialize the model's cache
for s in range(to_generate.min(), to_generate.max() + 1):
output = self(BracketedSequence(input, s, 1)).x
+
logits = output[:, s]
- logits = logits.log_softmax(dim=-1) / temperature
+ logits = logits.log_softmax(dim=1) / temperature
if forbidden_tokens is not None:
logits = logits.masked_fill(forbidden_tokens, float("-inf"))
else:
dist = torch.distributions.categorical.Categorical(logits=logits)
t_next = dist.sample()
- sum_logits += logits.log_softmax(dim=-1)[
+ sum_logits += logits.log_softmax(dim=1)[
torch.arange(t_next.size(0)), t_next
].sum()
+
input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
return sum_logits
total=(input.size(0) + batch_size - 1) // batch_size,
)
+ sum_logits = 0
+
with torch.autograd.no_grad():
t = model.training
model.eval()
- sum_logits = 0
-
for input, ar_mask in batches:
sum_logits += model.masked_inplace_autoregression(
input=input,
model.train(t)
- return sum_logits
+ return sum_logits
######################################################################
quizzes = torch.empty(
nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
)
- ar_mask = torch.full(quizzes.size(), 1, device=self.device)
- sum_logits = masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=quizzes,
- ar_mask=ar_mask,
- temperature=1.0,
- deterministic_synthesis=False,
- progress_bar_desc="creating quizzes",
- device=self.device,
- )
-
- # Should not be necessary though, the autoregression is done
- # in eval mode
- sum_logits = sum_logits.detach()
-
- average_logits = sum_logits / quizzes.numel()
+ ar_mask = torch.full(quizzes.size(), 1, device=self.device)
- # It's a bit brutal to do it twice, we should probably have a
- # moving average and apply it right away
+ temperature = 1
+ d_temperature = 1
- if desired_average_logits is not None:
- temperature = average_logits / desired_average_logits
- masked_inplace_autoregression(
+ while True:
+ sum_logits = masked_inplace_autoregression(
model=model,
batch_size=self.batch_size,
input=quizzes,
device=self.device,
)
+ average_logits = sum_logits / quizzes.size(0)
+
+ logger(f"{average_logits=} {desired_average_logits=}")
+
+ if desired_average_logits is None:
+ break
+
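+ # Ugly but simple: step-halving search on the temperature. Keep stepping in the
+ # same direction until average_logits crosses the target, then reverse and halve the step.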
+ if average_logits > desired_average_logits:
+ if d_temperature < 0:
+ d_temperature *= -0.5
+ temperature += d_temperature
+ else:
+ if d_temperature > 0:
+ d_temperature *= -0.5
+ temperature += d_temperature
+ logger(f"chaging temperature to {temperature}")
+
###############################################################
# Create the reverse quizzes
nb_correct = torch.cat(nb_correct, dim=0)
- filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
- with open(filename, "w") as f:
- for k in nb_correct:
- f.write(f"{k}\n")
+ # filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
+ # with open(filename, "w") as f:
+ # for k in nb_correct:
+ # f.write(f"{k}\n")
- return quizzes, nb_correct.sum(dim=0), average_logits
+ return quizzes, nb_correct.sum(dim=0), sum_logits