From 0d3bda1a286803536e7fb3dce4e1ff7c7a9de942 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Sat, 27 Jul 2024 05:38:31 +0200
Subject: [PATCH] Update.

---
 main.py  |  4 ++--
 mygpt.py | 37 ++++++++++++++++++-------------------
 2 files changed, 20 insertions(+), 21 deletions(-)

diff --git a/main.py b/main.py
index 3787e9f..848ac9c 100755
--- a/main.py
+++ b/main.py
@@ -645,7 +645,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
 ######################################################################
 
-#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
 
 
 def train_auto_encoder():
@@ -658,9 +657,10 @@ def train_auto_encoder():
         nb_blocks=args.nb_blocks,
         causal=False,
         dropout=args.dropout,
-        auto_encoder_dim=64,
     ).to(main_device)
 
+    model.make_auto_encoder(auto_encoder_dim=64)
+
     test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
 
     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
diff --git a/mygpt.py b/mygpt.py
index 9bec09e..b38cc99 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -256,7 +256,6 @@ class MyGPT(nn.Module):
         nb_blocks,
         causal=False,
         dropout=0.0,
-        auto_encoder_dim=-1,
         len_max=1e5,
     ):
         super().__init__()
@@ -298,24 +297,6 @@ class MyGPT(nn.Module):
                 ),
             ]
 
-        if auto_encoder_dim > 0:
-            self.encoder = nn.Sequential(
-                *(
-                    trunk_blocks[: nb_blocks // 2]
-                    + [EncoderHead(dim_model, auto_encoder_dim)]
-                )
-            )
-
-            self.decoder = nn.Sequential(
-                *(
-                    [
-                        DecoderBottom(auto_encoder_dim, dim_model),
-                        AddPositionalEncoding(len_max),
-                    ]
-                    + trunk_blocks[nb_blocks // 2 :]
-                )
-            )
-
         self.trunk = nn.Sequential(*trunk_blocks)
 
         self.readout = CacheWrapper(
@@ -337,6 +318,24 @@ class MyGPT(nn.Module):
         bs = self.readout(bs)
         return bs
 
+    def make_auto_encoder(self, auto_encoder_dim):
+        self.encoder = nn.Sequential(
+            *(
+                trunk_blocks[: nb_blocks // 2]
+                + [EncoderHead(dim_model, auto_encoder_dim)]
+            )
+        )
+
+        self.decoder = nn.Sequential(
+            *(
+                [
+                    DecoderBottom(auto_encoder_dim, dim_model),
+                    AddPositionalEncoding(len_max),
+                ]
+                + trunk_blocks[nb_blocks // 2 :]
+            )
+        )
+
     def encode(self, bs):
         bs = self.embedding(bs)
         z = self.encoder(bs)
-- 
2.20.1
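
Note: as committed, the new make_auto_encoder method still references trunk_blocks, nb_blocks, dim_model, and len_max, which are local variables of __init__ and no longer in scope once the construction code moves into a separate method, so calling it raises a NameError. The sketch below shows one possible way to make the method self-contained; it is not the author's code. It assumes, hypothetically, that __init__ is extended to store nb_blocks, dim_model, and len_max as attributes, and it rebuilds the block list from self.trunk, which __init__ already creates as nn.Sequential(*trunk_blocks). EncoderHead, DecoderBottom, and AddPositionalEncoding are the classes already defined in mygpt.py.

    # Hypothetical sketch, not part of the commit. Assumes __init__ gains:
    #     self.nb_blocks = nb_blocks
    #     self.dim_model = dim_model
    #     self.len_max = len_max
    # (method body as it would appear inside class MyGPT in mygpt.py)

    def make_auto_encoder(self, auto_encoder_dim):
        # self.trunk was built as nn.Sequential(*trunk_blocks) in __init__,
        # so iterating over it recovers the original block list.
        trunk_blocks = list(self.trunk)

        # First half of the trunk, then a projection into the latent space.
        self.encoder = nn.Sequential(
            *(
                trunk_blocks[: self.nb_blocks // 2]
                + [EncoderHead(self.dim_model, auto_encoder_dim)]
            )
        )

        # Projection back to dim_model, a fresh positional encoding, and the
        # second half of the trunk.
        self.decoder = nn.Sequential(
            *(
                [
                    DecoderBottom(auto_encoder_dim, self.dim_model),
                    AddPositionalEncoding(self.len_max),
                ]
                + trunk_blocks[self.nb_blocks // 2 :]
            )
        )

Two side effects of the two-phase construction are worth noting. First, main.py calls model.make_auto_encoder(...) after .to(main_device), so the freshly created encoder and decoder parameters would stay on the CPU; moving the model to the device again after the call avoids a device mismatch. Second, since nn.Module registers submodules on attribute assignment, creating the Adam optimizer from model.parameters() after the call, as main.py already does, ensures the new encoder and decoder weights are included in the optimization.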