From ac7b6f429faff0844dfe3eec3eb286af563663a1 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 31 Aug 2024 18:53:34 +0200 Subject: [PATCH] Update. --- main.py | 63 +++++++++++++++++++++++++++++++-------------------------- 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/main.py b/main.py index c11f5c2..ab625cc 100755 --- a/main.py +++ b/main.py @@ -313,7 +313,7 @@ log_string(f"vocabulary_size {vocabulary_size}") def bag_len(bag): - return sum([x[0].size(0) for x in bag]) + return sum([x.size(0) for x in bag]) def bag_to_tensors(bag): @@ -1033,8 +1033,6 @@ def ae_generate(model, input, mask_generate, noise_proba, nb_iterations_max=50): changed = True for it in range(nb_iterations_max): - print(f"{it=} {nb_iterations_max=}") - input_with_mask = NTC_channel_cat(input, mask_generate) logits = model(input_with_mask) dist = torch.distributions.categorical.Categorical(logits=logits) @@ -1260,10 +1258,6 @@ def one_ae_epoch(model, other_models, quiz_machine, n_epoch, local_device=main_d targets, logits = targets_and_prediction(model, input, mask_generate) - print( - f"{input.device=} {logits.device=} {targets.device=} {logits.device=} {mask_loss.device=}" - ) - loss = NTC_masked_cross_entropy(logits, targets, mask_loss) acc_train_loss += loss.item() * input.size(0) nb_train_samples += input.size(0) @@ -1342,46 +1336,56 @@ def generate_ae_c_quizzes(models, local_device=main_device): quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1) ) - records = [[] for _ in criteria] + duration_max = 600 # 3 * 3600 with torch.autograd.no_grad(): - while min([bag_len(bag) for bag in records]) < 128: + records = [[] for _ in criteria] + + start_time = time.perf_counter() + + while ( + time.perf_counter() < start_time + duration_max + and min([bag_len(bag) for bag in records]) < 128 + ): bl = [bag_len(bag) for bag in records] log_string(f"bag_len {bl}") model = models[torch.randint(len(models), (1,)).item()] - result = ae_generate(model, template, mask_generate, 0.0) + result = ae_generate(model, template, mask_generate, noise_proba) probas = torch.cat( [model_ae_proba_solutions(model, result)[:, None] for model in models], dim=1, ) + for c, r in zip(criteria, records): q = result[c(probas)] if q.size(0) > 0: r.append(q) - # for f, record in [("prediction", record_d), ("generation", record_nd)]: - # filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png" + for n, u in enumerate(records): + quizzes = torch.cat(u, dim=0)[:128] + filename = f"culture_{n_epoch:04d}_{n:02d}.png" - # result, predicted_parts, correct_parts = bag_to_tensors(record) + # result, predicted_parts, correct_parts = bag_to_tensors(record) - # l = [model_ae_proba_solutions(model, result) for model in models] - # probas = torch.cat([x[:, None] for x in l], dim=1) - # comments = [] + # l = [model_ae_proba_solutions(model, result) for model in models] + # probas = torch.cat([x[:, None] for x in l], dim=1) + # comments = [] + + # for l in probas: + # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l])) - # for l in probas: - # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l])) + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + filename, + quizzes=result, + # predicted_parts=predicted_parts, + # correct_parts=correct_parts, + # comments=comments, + ) - # quiz_machine.problem.save_quizzes_as_image( - # args.result_dir, - # filename, - # quizzes=result, - # predicted_parts=predicted_parts, - # correct_parts=correct_parts, - # comments=comments, - # ) - # log_string(f"wrote {filename}") + log_string(f"wrote {filename}") ###################################################################### @@ -1449,6 +1453,9 @@ for n_epoch in range(current_epoch, args.nb_epochs): # one_ae_epoch(models[0], models, quiz_machine, n_epoch, main_device) # exit(0) + if min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes: + generate_ae_c_quizzes(models, local_device=main_device) + ranked_models = sorted(models, key=lambda m: float(m.test_accuracy)) weakest_models = ranked_models[: len(gpus)] @@ -1472,8 +1479,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): for t in threads: t.join() - generate_ae_c_quizzes(models, local_device=main_device) - # -------------------------------------------------------------------- for model in models: -- 2.39.5