From 403194a404989f5f8d18c8008bdb6911144ca71e Mon Sep 17 00:00:00 2001
From: François Fleuret
Date: Wed, 18 Sep 2024 09:15:27 +0200
Subject: [PATCH] Update.

---
 main.py | 236 ++++++++------------------------------------------
 1 file changed, 32 insertions(+), 204 deletions(-)

diff --git a/main.py b/main.py
index 5f80fb5..05bb108 100755
--- a/main.py
+++ b/main.py
@@ -336,17 +336,6 @@ log_string(f"vocabulary_size {vocabulary_size}")
 
 ######################################################################
 
-
-def bag_len(bag):
-    return sum([x.size(0) for x in bag])
-
-
-def bag_to_tensors(bag):
-    return tuple(torch.cat([x[i] for x in bag], dim=0) for i in range(len(bag[0])))
-
-
-######################################################################
-
 # If we need to move an optimizer to a different device
 
 
@@ -651,22 +640,6 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     )
 
 
-######################################################################
-
-
-class TokenCat(nn.Module):
-    def __init__(self, m, n):
-        super().__init__()
-        self.m = m
-        self.n = n
-
-    def forward(self, x):
-        u = torch.cat([x.new_zeros(x.size(0), self.n), x], dim=1)
-        u = self.m(u)
-        u = u[:, self.n :]
-        return u
-
-
 ######################################################################
 
 import attae
@@ -701,186 +674,13 @@ for i in range(args.nb_models):
 ######################################################################
 
 
-def quiz_validation_(
-    models,
-    c_quizzes,
-    local_device,
-    nb_have_to_be_correct,
-    nb_have_to_be_wrong,
-    nb_mistakes_to_be_wrong,
-    nb_hints,
-    nb_runs=1,
-):
-    ######################################################################
-    # If too many with process per-batch
-
-    if c_quizzes.size(0) > args.inference_batch_size:
-        record = []
-        for q, nh in zip(
-            c_quizzes.split(args.inference_batch_size),
-            nb_hints.split(args.inference_batch_size),
-        ):
-            record.append(
-                quiz_validation(
-                    models=models,
-                    c_quizzes=q,
-                    local_device=local_device,
-                    nb_have_to_be_correct=nb_have_to_be_correct,
-                    nb_have_to_be_wrong=nb_have_to_be_wrong,
-                    nb_mistakes_to_be_wrong=nb_mistakes_to_be_wrong,
-                    nb_hints=nh,
-                    nb_runs=nb_runs,
-                )
-            )
-
-        r = []
-        for k in range(len(record[0])):
-            r.append(torch.cat([x[k] for x in record], dim=0))
-
-        return tuple(r)
-    ######################################################################
-
-    record_wrong = []
-    nb_correct, nb_wrong = 0, 0
-
-    for i, model in enumerate(models):
-        assert i == model.id  # a bit of paranoia
-        model = copy.deepcopy(model).to(local_device).eval()
-        correct, wrong = True, False
-        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
-            mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=c_quizzes,
-                quad_order=("A", "f_A", "B", "f_B"),
-                quad_mask=quad,
-            )
-
-            sub_correct, sub_wrong = False, True
-            for _ in range(nb_runs):
-                result = diffuser.generate(
-                    model=model,
-                    x_0=c_quizzes,
-                    mask_generate=mask_generate,
-                    nb_hints=nb_hints,
-                )
-
-                nb_mistakes = (result != c_quizzes).long().sum(dim=1)
-                sub_correct = sub_correct | (nb_mistakes == 0)
-                sub_wrong = sub_wrong & (nb_mistakes >= nb_mistakes_to_be_wrong)
-
-            correct = correct & sub_correct
-            wrong = wrong | sub_wrong
-
-        record_wrong.append(wrong[:, None])
-        nb_correct += correct.long()
-        nb_wrong += wrong.long()
-
-    to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong)
-
-    wrong = torch.cat(record_wrong, dim=1)
-
-    return to_keep, nb_correct, nb_wrong, wrong
-
-
-######################################################################
-
-
-def generate_c_quizzes_(models, nb, local_device=main_device):
-    # To be thread-safe we must make copies
-
-    def copy_for_inference(model):
-        return copy.deepcopy(model).to(local_device).eval()
-
-    quad_order = ("A", "f_A", "B", "f_B")
-
-    template = quiz_machine.problem.create_empty_quizzes(
-        nb=args.inference_batch_size, quad_order=quad_order
-    ).to(local_device)
-
-    wanted_nb = nb
-    nb_to_save = 256
-    nb_c_quizzes_per_model = torch.zeros(len(models), device=local_device)
-
-    with torch.autograd.no_grad():
-        record_c_quizzes, record_agreements = [], []
-
-        last_log = -1
-        start_time = time.perf_counter()
-
-        while nb_c_quizzes_per_model.min() < wanted_nb:
-            model = copy_for_inference(models[torch.randint(len(models), (1,)).item()])
-            generator_id = model.id
-
-            mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
-            )
-
-            c_quizzes = diffuser.generate(model, template, mask_generate)
-
-            to_keep = quiz_machine.problem.trivial(c_quizzes) == False
-            c_quizzes = c_quizzes[to_keep]
-
-            nb_hints = torch.full(
-                (c_quizzes.size(0),), args.nb_hints, device=c_quizzes.device
-            )
-
-            if c_quizzes.size(0) > 0:
-                to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation(
-                    models,
-                    c_quizzes,
-                    local_device,
-                    nb_have_to_be_correct=args.nb_have_to_be_correct,
-                    nb_have_to_be_wrong=args.nb_have_to_be_wrong,
-                    nb_mistakes_to_be_wrong=args.nb_mistakes_to_be_wrong,
-                    nb_hints=nb_hints,
-                    nb_runs=args.nb_runs,
-                )
-
-                # to_keep[...]=True
-
-                q = c_quizzes[to_keep]
-
-                if q.size(0) > 0:
-                    record_c_quizzes.append(q)
-                    a = (record_wrong == False)[to_keep]
-                    record_agreements.append(a)
-                    nb_c_quizzes_per_model += a.long().sum(dim=0)
-
-            duration = time.perf_counter() - start_time
-            nb_generated = nb_c_quizzes_per_model.min().item()
-
-            if last_log < 0 or duration > last_log + 5:
-                last_log = duration
-                if nb_generated > 0:
-                    if nb_generated < wanted_nb:
-                        d = (wanted_nb - nb_generated) * duration / nb_generated
-                        e = (
-                            datetime.datetime.now() + datetime.timedelta(seconds=d)
-                        ).strftime("%a %H:%M")
-                    else:
-                        e = "now!"
-                else:
-                    e = "???"
-
-                log_string(
-                    f"nb_generated {bag_len(record_c_quizzes)} model {generator_id} (finishes {e} -- {int((nb_generated * 3600)/duration)}/h)"
-                )
-
-    duration = time.perf_counter() - start_time
-
-    log_string(f"generate_c_quizz_speed {int(3600 * wanted_nb / duration)}/h")
-
-    c_quizzes = torch.cat(record_c_quizzes, dim=0)
-    agreements = torch.cat(record_agreements, dim=0)
-
-    return c_quizzes.to("cpu"), agreements.to("cpu")
-
-
-######################################################################
-
-
 def generate_c_quizzes(models, nb, local_device=main_device):
     record = []
     nb_validated = 0
+
+    start_time = time.perf_counter()
+    last_log = -1
+
 while nb_validated < nb:
         model = models[torch.randint(len(models), (1,)).item()]
         model = copy.deepcopy(model).to(local_device).eval()
@@ -911,6 +711,34 @@ def generate_c_quizzes(models, nb, local_device=main_device):
 
         log_string(f"generate_c_quizzes {nb_validated}")
 
+        #####################
+
+        duration = time.perf_counter() - start_time
+
+        if last_log < 0 or duration > last_log + 10:
+            last_log = duration
+            if nb_validated > 0:
+                if nb_validated < nb:
+                    d = (nb - nb_validated) * duration / nb_validated
+                    e = (
+                        datetime.datetime.now() + datetime.timedelta(seconds=d)
+                    ).strftime("%a %H:%M")
+                else:
+                    e = "now!"
+            else:
+                e = "???"
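+
+            # report validated count, generator model, expected finish time, and rate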
+            log_string(
+                f"nb_validated {nb_validated} model {model.id} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h)"
+            )
+
+    #####################
+
+    duration = time.perf_counter() - start_time
+
+    log_string(f"generate_c_quizz_speed {int(3600 * nb / duration)}/h")
+
     return torch.cat(record)
 
 
-- 
2.39.5