def masked_cross_entropy(output, targets, masks):
    loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
-    return (loss_per_token * masks).sum() / masks.expand_as(loss_per_token).sum()
+    return (loss_per_token * masks).mean()
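+# Note: the plain .mean() matches the previous masked normalization only
+# when every mask entry is 1 (as with the all-ones masks built in
+# generate() below); with partial masks the two differ. A small sketch
+# with hypothetical values:
+#
+#     loss = torch.tensor([[1.0, 3.0]])
+#     mask = torch.tensor([[1.0, 0.0]])
+#     (loss * mask).sum() / mask.sum()  # -> 1.0, average over masked tokens
+#     (loss * mask).mean()              # -> 0.5, divides by all tokens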
######################################################################
def generate(model, nb, local_device=main_device):
+    model.eval().to(local_device)
+
    all_input = quiz_machine.pure_noise(nb, local_device)
    all_masks = all_input.new_full(all_input.size(), 1)
correct_parts=correct_parts[:128],
)
-    model.test_accuracy = correct.sum() / quizzes.size(0)
+    nb_correct, nb_total = correct.sum().item(), quizzes.size(0)
+    model.test_accuracy = nb_correct / nb_total
+
+    log_string(
+        f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy:.02%})"
+    )
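+    # Formatting sketch: "%" in an f-string format spec scales by 100, so
+    # a fractional test_accuracy prints as a percentage, e.g.
+    # f"{0.8732:.02%}" gives "87.32%".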
# generate
)
+######################################################################
+
+
+class TokenCat(nn.Module):
+    # Wraps a module m: prepends n zero tokens to every sequence before
+    # the forward pass of m, then strips those n positions from the
+    # output, so the result has the same length as the input.
+    def __init__(self, m, n):
+        super().__init__()
+        self.m = m
+        self.n = n
+
+    def forward(self, x):
+        # prepend n zero tokens to each sequence of the batch
+        u = torch.cat([x.new_zeros(x.size(0), self.n), x], dim=1)
+        u = self.m(u)
+        # drop the n prefix positions
+        u = u[:, self.n :]
+        return u
+
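+# A quick shape check for TokenCat (illustrative sketch; nn.Identity
+# stands in for the wrapped model):
+#
+#     wrapped = TokenCat(nn.Identity(), n=10)
+#     x = torch.zeros(2, 5, dtype=torch.long)
+#     wrapped(x).shape  # -> torch.Size([2, 5]), the 10 prefix tokens are gone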
+
######################################################################
import attae
dropout=args.dropout,
).to(main_device)
+ # if i < args.nb_models//2:
+ # model = TokenCat(model, 10)
+
# model = torch.compile(model)
model.id = i
######################################################################
-def generate_c_quizzes(models, nb, local_device=main_device):
+def generate_c_quizzes_(models, nb, local_device=main_device):
# To be thread-safe we must make copies
def copy_for_inference(model):
######################################################################
+def generate_c_quizzes(models, nb, local_device=main_device):
+ record = []
+ nb_validated = 0
+ while nb_validated < nb:
+ model = models[torch.randint(len(models), (1,)).item()]
+ model = copy.deepcopy(model).to(local_device).eval()
+ generator_id = model.id
+
+        c_quizzes = generate(
+            model=model,  # already deep-copied to local_device above
+            nb=args.physical_batch_size,
+            local_device=local_device,
+        )
+
+        nb_correct, nb_wrong = 0, 0
+        for solver in models:
+            # every model grades the candidate batch
+            solver = copy.deepcopy(solver).to(local_device).eval()
+            result = predict_full(solver, c_quizzes, local_device=local_device)
+            nb_mistakes = (result != c_quizzes).long().sum(dim=1)
+            nb_correct += (nb_mistakes == 0).long()
+            nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long()
+
+ to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
+ nb_wrong >= args.nb_have_to_be_wrong
+ )
+
+        nb_validated += to_keep.long().sum().item()
+ record.append(c_quizzes[to_keep])
+
+ return torch.cat(record)
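+# A toy run of the validation gate above, assuming three models and
+# hypothetical thresholds nb_have_to_be_correct=2, nb_have_to_be_wrong=1:
+#
+#     nb_correct = torch.tensor([3, 2, 1])  # models solving each c-quiz
+#     nb_wrong = torch.tensor([0, 1, 2])    # models failing it badly
+#     (nb_correct >= 2) & (nb_wrong >= 1)   # -> tensor([False, True, False])
+#
+# so a c-quiz is kept only if enough models solve it while at least the
+# required number fail, i.e. it is neither trivial nor unsolvable.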
+
+
+######################################################################
+
+
def save_c_quizzes_with_scores(models, c_quizzes, filename, solvable_only=False):
l = []
ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
weakest_models = ranked_models[: len(gpus)]
- # None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
-
multithread_execution(
one_complete_epoch,
[(model, n_epoch, c_quizzes, gpu) for model, gpu in zip(weakest_models, gpus)],
)
- # --------------------------------------------------------------------
-
save_models(models)
+ # --------------------------------------------------------------------
+
duration = time.perf_counter() - start_time
str_duration = ""
if duration >= 60: