parser.add_argument("--nb_gpts", type=int, default=5)
-parser.add_argument("--check", action="store_true", default=False)
+parser.add_argument("--dirty_debug", action="store_true", default=False)
######################################################################
######################################################################
-if args.check:
- args.nb_train_samples = 25000
- args.nb_test_samples = 1000
+if args.dirty_debug:
+ args.nb_train_samples = 2500
+ args.nb_test_samples = 100
if args.physical_batch_size is None:
args.physical_batch_size = args.batch_size
):
kept = []
- sum_logits = 0
+ sum_logits, sum_nb_quizzes = 0, 0
while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
nb_to_generate = 4 * (nb_for_train + nb_for_test)
- new_quizzes, nb_correct, _sum_logits = task.create_new_quizzes(
+ new_quizzes, nb_correct, average_logits = task.create_new_quizzes(
n_epoch=n_epoch,
result_dir=args.result_dir,
logger=log_string,
desired_average_logits=desired_average_logits,
)
- sum_logits += _sum_logits
+ sum_logits += new_quizzes.size(0) * average_logits
+ sum_nb_quizzes += new_quizzes.size(0)
to_keep = new_quizzes[nb_correct == len(other_models) - 1]
+
+ if args.dirty_debug:
+ to_keep = new_quizzes
+
log_string(
f"keep {to_keep.size(0)}/{new_quizzes.size(0)} quizzes ({to_keep.size(0)*100/new_quizzes.size(0):.02f}%)"
)
+
kept.append(to_keep)
new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
log_string,
)
- return sum_logits / new_quizzes.size(0)
+ return sum_logits / sum_nb_quizzes
######################################################################
nb_new_quizzes_for_train = 1000
nb_new_quizzes_for_test = 100
-if args.check:
+if args.dirty_debug:
accuracy_to_make_quizzes = 0.0
nb_new_quizzes_for_train = 100
nb_new_quizzes_for_test = 10
self,
input,
ar_mask,
+ summed_logits,
temperature=1.0,
deterministic_synthesis=False,
forbidden_tokens=None,
forced_biases=None,
):
- sum_logits = 0
-
to_generate = (ar_mask.sum(0) > 0).nonzero()
if to_generate.min() > 0:
logits = output[:, s]
- logits = logits.log_softmax(dim=1) / temperature
+ logits = (logits / temperature).log_softmax(dim=-1)
if forbidden_tokens is not None:
logits = logits.masked_fill(forbidden_tokens, float("-inf"))
logits = logits + forced_biases[None, :]
if deterministic_synthesis:
- t_next = logits.argmax(1)
+ t_next = logits.argmax(-1)
else:
dist = torch.distributions.categorical.Categorical(logits=logits)
t_next = dist.sample()
- sum_logits += logits.log_softmax(dim=1)[
- torch.arange(t_next.size(0)), t_next
- ].sum()
+ if summed_logits is not None:
+ summed_logits += logits[torch.arange(t_next.size(0)), t_next].sum(
+ dim=-1
+ )
input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
- return sum_logits
-
def record_attention(self, v=True):
for m in self.modules():
if isinstance(m, QKVAttention):
batch_size,
input,
ar_mask,
+ summed_logits,
temperature,
deterministic_synthesis,
forbidden_tokens=None,
total=(input.size(0) + batch_size - 1) // batch_size,
)
- sum_logits = 0
-
with torch.autograd.no_grad():
t = model.training
model.eval()
for input, ar_mask in batches:
- sum_logits += model.masked_inplace_autoregression(
+ model.masked_inplace_autoregression(
input=input,
ar_mask=ar_mask,
+ summed_logits=summed_logits,
temperature=temperature,
deterministic_synthesis=deterministic_synthesis,
forbidden_tokens=forbidden_tokens,
model.train(t)
- return sum_logits
-
######################################################################
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
)
ar_mask = torch.full(quizzes.size(), 1, device=self.device)
+ summed_logits = torch.empty(nb, device=self.device)
temperature = 1
d_temperature = 1
while True:
- sum_logits = masked_inplace_autoregression(
+ summed_logits[...] = 0
+
+ masked_inplace_autoregression(
model=model,
batch_size=self.batch_size,
input=quizzes,
ar_mask=ar_mask,
+ summed_logits=summed_logits,
temperature=temperature,
deterministic_synthesis=False,
progress_bar_desc="creating quizzes",
device=self.device,
)
- average_logits = sum_logits / quizzes.size(0)
+ average_logits = summed_logits.mean()
logger(f"{average_logits=} {desired_average_logits=}")
break
# Oh man that's ugly
- if average_logits > desired_average_logits:
+ if average_logits < desired_average_logits:
if d_temperature < 0:
d_temperature *= -0.5
temperature += d_temperature
- else:
+ elif average_logits > desired_average_logits * 0.95:
if d_temperature > 0:
d_temperature *= -0.5
temperature += d_temperature
+ else:
+ break
logger(f"chaging temperature to {temperature}")
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=True,
progress_bar_desc="solving quizzes",
batch_size=self.batch_size,
input=reverse_result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=True,
progress_bar_desc="solving reversed quizzes",
# for k in nb_correct:
# f.write(f"{k}\n")
- return quizzes, nb_correct.sum(dim=0), sum_logits
+ return quizzes, nb_correct.sum(dim=0), summed_logits.mean()