From 3cd3db9ee1cb1b462f375b482e6915769d43a73d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 13 Sep 2024 11:36:26 +0200 Subject: [PATCH] Update. --- attae.py | 6 ++-- main.py | 88 ++------------------------------------------------------ 2 files changed, 6 insertions(+), 88 deletions(-) diff --git a/attae.py b/attae.py index e201f60..05084ba 100755 --- a/attae.py +++ b/attae.py @@ -6,7 +6,8 @@ import torch from torch import nn from torch.nn import functional as F -from torch.nn.attention.flex_attention import flex_attention + +# from torch.nn.attention.flex_attention import flex_attention ###################################################################### @@ -105,8 +106,7 @@ class AttentionAE(nn.Module): assert dim_model % nb_heads == 0 self.embedding = nn.Sequential( - nn.Embedding(2 * vocabulary_size, dim_model), - nn.Dropout(dropout), + nn.Embedding(2 * vocabulary_size, dim_model), nn.Dropout(dropout) ) self.positional_encoding = VaswaniPositionalEncoding(len_max) diff --git a/main.py b/main.py index e090f86..92a34f1 100755 --- a/main.py +++ b/main.py @@ -729,7 +729,7 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None): noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device) - one_iteration_prediction = deterministic(mask_generate)[:, None] + single_iteration = deterministic(mask_generate)[:, None] if mask_hints is not None: mask_generate = mask_generate * (1 - mask_hints) @@ -746,12 +746,11 @@ def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample() - hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + ( - 1 - one_iteration_prediction + hat_x_t_minus_1 = single_iteration * hat_x_0 + ( + 1 - single_iteration ) * sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t) if hat_x_t_minus_1.equal(x_t): - # log_string(f"exit after {it+1} iterations") break else: changed = changed & (hat_x_t_minus_1 != x_t).max(dim=1).values @@ -794,54 +793,6 @@ def model_ae_proba_solutions(model, input, log_probas=False, reduce=True): return (-loss).exp() -def model_ae_argmax_nb_mistakes(model, input): - record = [] - - for x_0 in input.split(args.batch_size): - nb_mistakes = 0 - for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]: - mask_generate = quiz_machine.make_quiz_mask( - quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad - ) - - logits = logits_hat_x_0_from_random_iteration( - model, x_0, mask_generate, prompt_noise=args.prompt_noise - ) - - predicted = logits.argmax(dim=-1) - - nb_mistakes = nb_mistakes + ( - mask_generate * predicted != mask_generate * x_0 - ).long().sum(dim=1) - - record.append(nb_mistakes) - - return torch.cat(record, dim=0) - - -###################################################################### - - -def model_ae_argmax_predictions(model, input): - result = input.clone() - # result[...] = 0 - - for r, x_0 in zip(result.split(args.batch_size), input.split(args.batch_size)): - for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]: - mask_generate = quiz_machine.make_quiz_mask( - quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad - ) - logits = logits_hat_x_0_from_random_iteration( - model, x_0, mask_generate, prompt_noise=args.prompt_noise - ) - - hat_x_0 = logits.argmax(dim=-1) - - r[...] = (1 - mask_generate) * r + mask_generate * hat_x_0 - - return result - - ###################################################################### @@ -1013,39 +964,6 @@ for i in range(args.nb_models): ###################################################################### -def save_badness_statistics( - n_epoch, models, c_quizzes, suffix=None, local_device=main_device -): - for model in models: - model.eval().to(local_device) - c_quizzes = c_quizzes.to(local_device) - with torch.autograd.no_grad(): - log_probas = sum( - [model_ae_proba_solutions(model, c_quizzes) for model in models] - ) - i = log_probas.sort().indices - - suffix = "" if suffix is None else "_" + suffix - - filename = f"culture_badness_{n_epoch:04d}{suffix}.png" - - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - filename, - quizzes=c_quizzes[i[:128]], - # predicted_parts=predicted_parts, - # correct_parts=correct_parts, - # comments=comments, - delta=True, - nrow=8, - ) - - log_string(f"wrote {filename}") - - -###################################################################### - - def quiz_validation( models, c_quizzes, -- 2.39.5