From eb86c22964f03f186ee225f129bc260128b10f9a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 24 Mar 2023 18:34:30 +0100 Subject: [PATCH] Update --- beaver.py | 11 +++++++++++ mygpt.py | 24 +++++++++++++++++++----- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/beaver.py b/beaver.py index 6a6343d..4f41832 100755 --- a/beaver.py +++ b/beaver.py @@ -66,6 +66,8 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa parser.add_argument("--random_regression_order", action="store_true", default=False) +parser.add_argument("--noncausal_prompt", action="store_true", default=False) + parser.add_argument("--no_checkpoint", action="store_true", default=False) parser.add_argument("--overwrite_results", action="store_true", default=False) @@ -517,6 +519,14 @@ log_string(f"vocabulary_size {vocabulary_size}") ############################## +amm_generator = None + +if args.noncausal_prompt: + amm_generator = lambda d: torch.logical_and( + torch.arange(d)[None, None, :, None] < torch.arange(d)[None, None, None, :], + torch.arange(d)[None, None, :, None] >= d // 2, + ) + model = mygpt.MyGPT( vocabulary_size=vocabulary_size, dim_model=args.dim_model, @@ -526,6 +536,7 @@ model = mygpt.MyGPT( nb_blocks=args.nb_blocks, causal=True, dropout=args.dropout, + amm_generator=amm_generator, ) model.to(device) diff --git a/mygpt.py b/mygpt.py index 75adbf6..7166788 100755 --- a/mygpt.py +++ b/mygpt.py @@ -132,13 +132,28 @@ class AddPositionalEncoding(nn.Module): class QKVAttention(nn.Module): def __init__( - self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0 + self, + dim_in, + dim_qk, + dim_v, + nb_heads=1, + causal=False, + attention_dropout=0.0, + amm_generator=None, ): super().__init__() def randw(*d): return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + if amm_generator is None: + self.amm_generator = ( + lambda d: torch.arange(d)[None, None, :, None] + < torch.arange(d)[None, None, None, :] + ) + else: + self.amm_generator = amm_generator + self.causal = causal self.attention_dropout = attention_dropout @@ -175,10 +190,7 @@ class QKVAttention(nn.Module): if self.causal: if bs_q.first == 0: - self.cache_attzero = ( - torch.arange(x_q.size(1), device=q.device)[None, None, :, None] - < torch.arange(x_q.size(1), device=q.device)[None, None, None, :] - ) + self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device) a = a.masked_fill( self.cache_attzero[ :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb @@ -215,6 +227,7 @@ class MyGPT(nn.Module): causal=False, dropout=0.0, len_max=1e5, + amm_generator=None, ): super().__init__() @@ -238,6 +251,7 @@ class MyGPT(nn.Module): nb_heads=nb_heads, causal=causal, attention_dropout=dropout, + amm_generator=amm_generator, ), ), WithResidual( -- 2.39.5