From 3cac1c28149834f30a046693fc5f9c01f1369da4 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Thu, 19 Sep 2024 08:22:26 +0200
Subject: [PATCH] Update.

---
 attae.py | 30 ++++++++++++------------------
 main.py  | 31 ++++++++++---------------------
 2 files changed, 22 insertions(+), 39 deletions(-)

diff --git a/attae.py b/attae.py
index b4db3ab..a9bdeba 100755
--- a/attae.py
+++ b/attae.py
@@ -127,6 +127,7 @@ class AttentionAE(nn.Module):
                     dim_qk=dim_keys,
                     dim_v=dim_model // nb_heads,
                     nb_heads=nb_heads,
+                    attention=attention,
                     attention_dropout=dropout,
                 ),
             ),
@@ -170,28 +171,20 @@ class FunctionalAttentionAE(AttentionAE):
         dim_keys,
         dim_hidden,
         nb_heads,
-        nb_work_tokens,
         nb_blocks,
+        nb_work_tokens=100,
         dropout=0.0,
         len_max=1e5,
     ):
-        # def functional_mask(b, h, q_idx, kv_idx):
-        #     return (
-        #         (q_idx < nb_work_tokens)
-        #         | (kv_idx < nb_work_tokens)
-        #         | ((q_idx - nb_work_tokens) // 200 == (kv_idx - nb_work_tokens) // 200)
-        #     )
-
-        # block_mask = create_block_mask(
-        #     functional_mask,
-        #     B=None,
-        #     H=None,
-        #     Q_LEN=400 + nb_work_tokens,
-        #     KV_LEN=400 + nb_work_tokens,
-        # )
-
-        # def functional_attention(q, k, v):
-        #     return flex_attention(q, k, v, block_mask=block_mask)
+        def no_peek_attention(q, k, v):
+            a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
+            n = self.nb_work_tokens
+            s = (q.size(2) - n) // 2
+            a[:, :, n + 0 * s : n + 1 * s, n + 0 * s : n + 1 * s] = float("-inf")
+            a[:, :, n + 1 * s : n + 2 * s, n + 1 * s : n + 2 * s] = float("-inf")
+            a = a.softmax(dim=3)
+            y = torch.einsum("nhts,nhsd->nhtd", a, v)
+            return y
 
         AttentionAE.__init__(
             self,
@@ -201,6 +194,7 @@ class FunctionalAttentionAE(AttentionAE):
             dim_hidden,
             nb_heads,
             nb_blocks,
+            attention=no_peek_attention,
             dropout=0.0,
             len_max=1e5,
         )
diff --git a/main.py b/main.py
index 16edcdc..d903693 100755
--- a/main.py
+++ b/main.py
@@ -198,6 +198,8 @@ if args.seed >= 0:
 
 
 def log_string(s):
+    """Print the given string prefixed with a time stamp, and log it to log_file if it is not None."""
+
     t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
 
     if log_file is not None:
@@ -352,13 +354,10 @@ def add_noise_imt(imt_set):
 
 ######################################################################
 
-
-# IMT for input / masks / target
-
-# Generate a batch for prediction
+# Prediction
 
 
-def batch_for_prediction_imt(input):
+def samples_for_prediction_imt(input):
     nb = input.size(0)
     masks = input.new_zeros(input.size())
     u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
@@ -421,7 +420,7 @@ def predict_full(model, input, with_perturbations=False, local_device=main_devic
 ######################################################################
 
 
-def batch_for_generation_imt(input):
+def samples_for_generation_imt(input):
     nb = input.size(0)
     probs_iterations = 0.1 ** torch.linspace(
         0, 1, args.diffusion_nb_iterations, device=input.device
@@ -511,13 +510,13 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
         # Half of the samples train the prediction, and we inject noise in
         # all, and hints in half
 
-        b_p = batch_for_prediction_imt(q_p)
+        b_p = samples_for_prediction_imt(q_p)
         b_p = add_noise_imt(b_p)
         half = torch.rand(b_p.size(0)) < 0.5
         b_p[half] = add_hints_imt(b_p[half])
 
         # The other half are denoising examples for the generation
-        b_g = batch_for_generation_imt(q_g)
+        b_g = samples_for_generation_imt(q_g)
 
         imt_set = torch.cat([b_p, b_g])
         imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
@@ -590,7 +589,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     quizzes = quiz_machine.quiz_set(
         args.nb_test_samples, c_quizzes, args.c_quiz_multiplier
     )
-    imt_set = batch_for_prediction_imt(quizzes.to(local_device))
+    imt_set = samples_for_prediction_imt(quizzes.to(local_device))
     result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
 
     masks = imt_set[:, 1].to("cpu")
@@ -634,18 +633,8 @@ import attae
 models = []
 
 for i in range(args.nb_models):
-    # model = attae.FunctionalAttentionAE(
-    #     vocabulary_size=vocabulary_size * 2,
-    #     dim_model=args.dim_model,
-    #     dim_keys=args.dim_keys,
-    #     dim_hidden=args.dim_hidden,
-    #     nb_heads=args.nb_heads,
-    #     nb_blocks=args.nb_blocks,
-    #     nb_work_tokens=10,
-    #     dropout=args.dropout,
-    # )
-
-    model = attae.AttentionAE(
+    model = attae.FunctionalAttentionAE(
+        # model = attae.AttentionAE(
         vocabulary_size=vocabulary_size * 2,
         dim_model=args.dim_model,
         dim_keys=args.dim_keys,
-- 
2.39.5