From 3c72ba5e16eba4e19167a5d680f2de9f4ecf00d8 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 23:01:03 +0200 Subject: [PATCH] Update. --- attae.py | 44 ++++++++++++++++++++++++++++++++++---------- grids.py | 25 ++++++++++++++++++++----- main.py | 21 ++++++++++++++++++--- 3 files changed, 72 insertions(+), 18 deletions(-) diff --git a/attae.py b/attae.py index 06deed2..b4db3ab 100755 --- a/attae.py +++ b/attae.py @@ -10,7 +10,7 @@ import torch from torch import nn from torch.nn import functional as F -# from torch.nn.attention.flex_attention import flex_attention +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask ###################################################################### @@ -44,7 +44,7 @@ class WithResidual(nn.Module): ###################################################################### -def attention(q, k, v): +def vanilla_attention(q, k, v): a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3)) a = a.softmax(dim=3) y = torch.einsum("nhts,nhsd->nhtd", a, v) @@ -61,6 +61,7 @@ class MHAttention(nn.Module): dim_qk, dim_v, nb_heads=1, + attention=vanilla_attention, attention_dropout=0.0, ): super().__init__() @@ -68,6 +69,7 @@ class MHAttention(nn.Module): def randw(*d): return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + self.attention = attention self.attention_dropout = attention_dropout self.w_q = randw(nb_heads, dim_qk, dim_model) self.w_k = randw(nb_heads, dim_qk, dim_model) @@ -81,7 +83,7 @@ class MHAttention(nn.Module): q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q) k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k) v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v) - y = attention(q, k, v) + y = self.attention(q, k, v) y = torch.einsum("nhtd,hdc->ntc", y, self.w_o) return y @@ -99,6 +101,7 @@ class AttentionAE(nn.Module): dim_hidden, nb_heads, nb_blocks, + attention=vanilla_attention, dropout=0.0, len_max=1e5, ): @@ -159,7 +162,7 @@ class AttentionAE(nn.Module): ###################################################################### -class MaskedAttentionAE(nn.Module): +class FunctionalAttentionAE(AttentionAE): def __init__( self, vocabulary_size, @@ -167,13 +170,32 @@ class MaskedAttentionAE(nn.Module): dim_keys, dim_hidden, nb_heads, + nb_work_tokens, nb_blocks, dropout=0.0, len_max=1e5, ): - super().__init__() - self.core = AttentionAE( - vocabulary_size * 2, + # def functional_mask(b, h, q_idx, kv_idx): + # return ( + # (q_idx < nb_work_tokens) + # | (kv_idx < nb_work_tokens) + # | ((q_idx - nb_work_tokens) // 200 == (kv_idx - nb_work_tokens) // 200) + # ) + + # block_mask = create_block_mask( + # functional_mask, + # B=None, + # H=None, + # Q_LEN=400 + nb_work_tokens, + # KV_LEN=400 + nb_work_tokens, + # ) + + # def functional_attention(q, k, v): + # return flex_attention(q, k, v, block_mask=block_mask) + + AttentionAE.__init__( + self, + vocabulary_size, dim_model, dim_keys, dim_hidden, @@ -182,22 +204,24 @@ class MaskedAttentionAE(nn.Module): dropout=0.0, len_max=1e5, ) + self.nb_work_tokens = nb_work_tokens def forward(self, x): - x = x[:, :, 0] * 2 + x[:, :, 1] - return self.core(x) + x = torch.cat([x.new_zeros(x.size(0), self.nb_work_tokens), x], dim=1) + return AttentionAE.forward(self, x)[:, self.nb_work_tokens :] ###################################################################### if __name__ == "__main__": - model = AttentionAE( + model = FunctionalAttentionAE( vocabulary_size=100, dim_model=16, dim_keys=64, dim_hidden=32, nb_heads=4, + nb_work_tokens=10, nb_blocks=4, dropout=0.1, ) diff --git a/grids.py b/grids.py index 6b2ea23..23a3d12 100755 --- a/grids.py +++ b/grids.py @@ -134,8 +134,16 @@ def grow_islands(nb, height, width, nb_seeds, nb_iterations): class Grids(problem.Problem): + # grid_gray=64 + # thickness=1 + # background_gray=255 + + grid_gray = 255 + thickness = 0 + background_gray = grid_gray + named_colors = [ - ("white", [255, 255, 255]), + ("white", [background_gray, background_gray, background_gray]), # ("white", [224, 224, 224]), ("red", [255, 0, 0]), ("green", [0, 192, 0]), @@ -380,8 +388,9 @@ class Grids(problem.Problem): y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale) if grids: - y[:, :, :, torch.arange(0, y.size(3), scale)] = 64 - y[:, :, torch.arange(0, y.size(2), scale), :] = 64 + for t in range(self.thickness): + y[:, :, :, torch.arange(t, y.size(3), scale)] = self.grid_gray + y[:, :, torch.arange(t, y.size(2), scale), :] = self.grid_gray for n in range(m.size(0)): for i in range(m.size(1)): @@ -463,11 +472,17 @@ class Grids(problem.Problem): ) frame, white, gray, green, red = torch.tensor( - [[64, 64, 64], [255, 255, 255], [200, 200, 200], [0, 255, 0], [255, 0, 0]], + [ + [self.grid_gray, self.grid_gray, self.grid_gray], + [255, 255, 255], + [200, 200, 200], + [0, 255, 0], + [255, 0, 0], + ], device=quizzes.device, ) - thickness = 1 if grids else 0 + thickness = self.thickness if delta: u = (A != f_A).long() diff --git a/main.py b/main.py index c7131c3..16edcdc 100755 --- a/main.py +++ b/main.py @@ -512,9 +512,9 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True): # Half of the samples train the prediction, and we inject noise in # all, and hints in half b_p = batch_for_prediction_imt(q_p) - i = torch.rand(b_p.size(0)) < 0.5 b_p = add_noise_imt(b_p) - b_p[i] = add_hints_imt(b_p[i]) + half = torch.rand(b_p.size(0)) < 0.5 + b_p[half] = add_hints_imt(b_p[half]) # The other half are denoising examples for the generation b_g = batch_for_generation_imt(q_g) @@ -610,7 +610,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): # Compute the test accuracy - nb_correct, nb_total = correct.sum(), quizzes.size(0) + nb_correct, nb_total = correct.sum().item(), quizzes.size(0) model.test_accuracy = nb_correct / nb_total log_string( @@ -634,6 +634,17 @@ import attae models = [] for i in range(args.nb_models): + # model = attae.FunctionalAttentionAE( + # vocabulary_size=vocabulary_size * 2, + # dim_model=args.dim_model, + # dim_keys=args.dim_keys, + # dim_hidden=args.dim_hidden, + # nb_heads=args.nb_heads, + # nb_blocks=args.nb_blocks, + # nb_work_tokens=10, + # dropout=args.dropout, + # ) + model = attae.AttentionAE( vocabulary_size=vocabulary_size * 2, dim_model=args.dim_model, @@ -975,6 +986,10 @@ for n_epoch in range(current_epoch, args.nb_epochs): ranked_models = sorted(models, key=lambda m: float(m.test_accuracy)) weakest_models = ranked_models[: len(gpus)] + log_string( + f"weakest_accuracies {[model.test_accuracy for model in weakest_models]}" + ) + multithread_execution( one_complete_epoch, [(model, n_epoch, c_quizzes, gpu) for model, gpu in zip(weakest_models, gpus)], -- 2.39.5