From e9ecef42ebc80b640a5530a29e9a845d86761644 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 10 Oct 2024 10:12:11 +0200 Subject: [PATCH] Update. --- attae.py | 21 +++++++++++++++++++++ main.py | 19 +++++++++++++------ 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/attae.py b/attae.py index 3eb6c4e..0d36a33 100755 --- a/attae.py +++ b/attae.py @@ -54,6 +54,27 @@ class BlockRandomPositionalEncoding(nn.Module): ###################################################################### +class AdHocPositionalEncoding(nn.Module): + def __init__(self, dim_model, value, trainable=False): + super().__init__() + if trainable: + self.value = nn.Parameter(value.clone()) + else: + self.register_buffer("value", value.clone()) + self.fc = nn.Linear( + in_features=value.size(-1) + dim_model, out_features=dim_model + ) + + def forward(self, x): + value = self.value[None, :, :].repeat(x.size(0), 1, 1) + x = torch.cat([value, x], dim=2) + y = self.fc(x) + return y + + +###################################################################### + + class WithResidual(nn.Module): def __init__(self, *f): super().__init__() diff --git a/main.py b/main.py index ba5c6e2..d5c1c5c 100755 --- a/main.py +++ b/main.py @@ -589,7 +589,7 @@ def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_de problem.save_quizzes_as_image( args.result_dir, - f"culture_prediction_{n_epoch}_{model.id}.png", + f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png", quizzes=result[:128], predicted_parts=predicted_parts[:128], correct_parts=correct_parts[:128], @@ -601,7 +601,7 @@ def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_de problem.save_quizzes_as_image( args.result_dir, - f"culture_generation_{n_epoch}_{model.id}.png", + f"culture_generation_{n_epoch:04d}_{model.id:02d}.png", quizzes=result[:128], ) @@ -927,12 +927,19 @@ if args.test == "aebn": len_max=1e4, ) - model.positional_encoding = attae.BlockRandomPositionalEncoding( - args.dim_model, 100, 4 - ) + # model.positional_encoding = attae.BlockRandomPositionalEncoding( + # args.dim_model, 100, 4 + # ) + + i = torch.arange(400)[:, None] + k = [2**k for k in range(4)] + [10 * 2**k for k in range(4)] + [100, 200] + k = torch.tensor(k)[None, :] + pe = (i // k) % 2 + + model.positional_encoding = attae.AdHocPositionalEncoding(args.dim_model, pe) model.trunk = attae.Reasoning( - nb_f_tokens=25, + nb_f_tokens=8, nb_chunks=2, dim_model=args.dim_model, dim_qk=args.dim_keys, -- 2.39.5