From a154f0c2c829aaea947852b91e024a8925e1c875 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Thu, 8 Aug 2024 09:27:34 +0200
Subject: [PATCH] Update.

---
 main.py  |  9 +++++++--
 mygpt.py | 23 +++++++++--------------
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/main.py b/main.py
index 86eafea..c77a7f3 100755
--- a/main.py
+++ b/main.py
@@ -990,6 +990,11 @@ def train_complexifier(model_gen, model_pred1, model_pred2):
 
 models = []
 
+
+def compute_causal_attzero(t_q, t_k):
+    return t_q < t_k
+
+
 for k in range(args.nb_gpts):
     log_string(f"creating model {k} and its w_quizzes")
 
@@ -1000,7 +1005,7 @@ for k in range(args.nb_gpts):
         dim_hidden=args.dim_hidden,
         nb_heads=args.nb_heads,
         nb_blocks=args.nb_blocks,
-        causal=True,
+        compute_attzero=compute_causal_attzero,
         dropout=args.dropout,
     ).to(main_device)
 
@@ -1144,7 +1149,7 @@ if args.test == "generator":
         dim_hidden=args.dim_hidden,
         nb_heads=args.nb_heads,
         nb_blocks=args.nb_blocks,
-        causal=True,
+        compute_attzero=compute_causal_attzero,
         dropout=args.dropout,
     ).to(main_device)
 
diff --git a/mygpt.py b/mygpt.py
index 15ed80e..2706143 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -145,7 +145,7 @@ class QKVAttention(nn.Module):
         dim_qk,
         dim_v,
         nb_heads=1,
-        causal=False,
+        compute_attzero=None,
         attention_dropout=0.0,
     ):
         super().__init__()
@@ -153,7 +153,7 @@ class QKVAttention(nn.Module):
         def randw(*d):
             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
 
-        self.causal = causal
+        self.compute_attzero = compute_attzero
         self.attention_dropout = attention_dropout
         self.record_attention = False
 
@@ -165,10 +165,6 @@ class QKVAttention(nn.Module):
     def forward(self, bs_q):
         x_q = bs_q.x
 
-        assert (
-            self.causal or bs_q.complete()
-        ), "Partial evaluation is only possible for causal models"
-
         if bs_q.first == 0:
             self.cache_k = x_q.new_zeros(
                 x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
@@ -193,12 +189,12 @@
             "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
         ) / math.sqrt(self.w_q.size(1))
 
-        if self.causal:
+        if self.compute_attzero is not None:
             if bs_q.first == 0:
-                self.cache_attzero = (
-                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
-                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
-                )
+                self.cache_attzero = self.compute_attzero(
+                    torch.arange(x_q.size(1), device=q.device)[:, None],
+                    torch.arange(x_q.size(1), device=q.device)[None, :],
+                )[None, None, :, :]
             a = a.masked_fill(
                 self.cache_attzero[
                     :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
@@ -251,7 +247,7 @@ class MyGPT(nn.Module):
         dim_hidden,
         nb_heads,
         nb_blocks,
-        causal=False,
+        compute_attzero=None,
         autoencoder_dim=-1,
         dropout=0.0,
         len_max=1e5,
@@ -281,7 +277,7 @@ class MyGPT(nn.Module):
                     dim_qk=dim_keys,
                     dim_v=dim_model // nb_heads,
                     nb_heads=nb_heads,
-                    causal=causal,
+                    compute_attzero=compute_attzero,
                     attention_dropout=dropout,
                 ),
             ),
@@ -407,7 +403,6 @@ if __name__ == "__main__":
         nb_heads=2,
         nb_blocks=2,
         dropout=0.1,
-        causal=True,
     )
 
     model.eval()
--
2.20.1
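
For context: the patch replaces the boolean causal flag with a compute_attzero
hook that receives broadcastable tensors of query and key positions and returns
True for the attention entries to fill with -inf before the softmax; passing
None keeps full, unmasked attention, as causal=False did before. Below is a
minimal standalone sketch of such hooks. compute_causal_attzero is taken from
the diff above; compute_block_attzero is a hypothetical illustration that does
not appear in the repository, shown only to make the interface concrete.

    import torch

    def compute_causal_attzero(t_q, t_k):
        # Mask (True -> -inf) every entry where the key position is strictly
        # after the query position, i.e. standard causal attention.
        return t_q < t_k

    def compute_block_attzero(t_q, t_k, block_size=4):
        # Hypothetical variant: attend only within fixed-size blocks,
        # illustrating that any position-based mask fits the same hook.
        return (t_q // block_size) != (t_k // block_size)

    t = torch.arange(8)
    causal_mask = compute_causal_attzero(t[:, None], t[None, :])  # (8, 8) bool
    block_mask = compute_block_attzero(t[:, None], t[None, :])    # (8, 8) bool
    print(causal_mask.int())
    print(block_mask.int())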