Update: factor auto-encoder construction out of MyGPT.__init__ into an explicit make_auto_encoder() method.
author    François Fleuret <francois@fleuret.org>
          Sat, 27 Jul 2024 03:38:31 +0000 (05:38 +0200)
committer François Fleuret <francois@fleuret.org>
          Sat, 27 Jul 2024 03:38:31 +0000 (05:38 +0200)
main.py
mygpt.py

diff --git a/main.py b/main.py
index 3787e9f..848ac9c 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -645,7 +645,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
 
 ######################################################################
-#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
 
 
 def train_auto_encoder():
@@ -658,9 +657,11 @@ def train_auto_encoder():
         nb_blocks=args.nb_blocks,
         causal=False,
         dropout=args.dropout,
-        auto_encoder_dim=64,
     ).to(main_device)
 
+    model.make_auto_encoder(auto_encoder_dim=64)
+    model.to(main_device)  # move the newly created encoder/decoder to the device too
+
     test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
 
     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
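
The call order here matters: make_auto_encoder() registers new submodules (an EncoderHead and a DecoderBottom) and therefore new parameters, so it has to run before the optimizer is built, and the model has to be moved to the device afterwards so the fresh modules land there too. A minimal sketch of the resulting sequence in train_auto_encoder(), assuming MyGPT is imported from mygpt.py and eliding the constructor arguments not shown in the hunk:

    model = MyGPT(
        # ... constructor arguments as in main.py ...
        nb_blocks=args.nb_blocks,
        causal=False,
        dropout=args.dropout,
    )

    # Registers self.encoder / self.decoder, i.e. new parameters.
    model.make_auto_encoder(auto_encoder_dim=64)

    # nn.Module.to() is in-place, so this moves the new submodules as well.
    model.to(main_device)

    # Built last, so it sees the auto-encoder parameters too.
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)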
diff --git a/mygpt.py b/mygpt.py
index 9bec09e..b38cc99 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -256,7 +256,6 @@ class MyGPT(nn.Module):
         nb_blocks,
         causal=False,
         dropout=0.0,
-        auto_encoder_dim=-1,
         len_max=1e5,
     ):
         super().__init__()
@@ -298,24 +297,12 @@ class MyGPT(nn.Module):
                 ),
             ]
 
-        if auto_encoder_dim > 0:
-            self.encoder = nn.Sequential(
-                *(
-                    trunk_blocks[: nb_blocks // 2]
-                    + [EncoderHead(dim_model, auto_encoder_dim)]
-                )
-            )
-
-            self.decoder = nn.Sequential(
-                *(
-                    [
-                        DecoderBottom(auto_encoder_dim, dim_model),
-                        AddPositionalEncoding(len_max),
-                    ]
-                    + trunk_blocks[nb_blocks // 2 :]
-                )
-            )
-
+        # Keep what make_auto_encoder() needs to rebuild the two halves later.
+        self.trunk_blocks = trunk_blocks
+        self.nb_blocks = nb_blocks
+        self.dim_model = dim_model
+        self.len_max = len_max
+
         self.trunk = nn.Sequential(*trunk_blocks)
 
         self.readout = CacheWrapper(
@@ -337,6 +324,24 @@ class MyGPT(nn.Module):
         bs = self.readout(bs)
         return bs
 
+    def make_auto_encoder(self, auto_encoder_dim):
+        self.encoder = nn.Sequential(
+            *(
+                self.trunk_blocks[: self.nb_blocks // 2]
+                + [EncoderHead(self.dim_model, auto_encoder_dim)]
+            )
+        )
+
+        self.decoder = nn.Sequential(
+            *(
+                [
+                    DecoderBottom(auto_encoder_dim, self.dim_model),
+                    AddPositionalEncoding(self.len_max),
+                ]
+                + self.trunk_blocks[self.nb_blocks // 2 :]
+            )
+        )
+
     def encode(self, bs):
         bs = self.embedding(bs)
         z = self.encoder(bs)
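
For intuition, a self-contained sketch of the split-trunk auto-encoder wiring that make_auto_encoder() builds: the first half of the trunk plus a bottleneck head form the encoder, and a bottom projection plus the second half form the decoder. The nn.Linear stand-ins and toy dimensions are assumptions, and the AddPositionalEncoding step is omitted; only the wiring mirrors mygpt.py:

    import torch
    import torch.nn as nn

    dim_model, auto_encoder_dim, nb_blocks = 16, 8, 4

    # Stand-ins for the real transformer blocks; only the shapes matter here.
    trunk_blocks = [nn.Linear(dim_model, dim_model) for _ in range(nb_blocks)]

    encoder = nn.Sequential(
        *(
            trunk_blocks[: nb_blocks // 2]
            + [nn.Linear(dim_model, auto_encoder_dim)]  # role of EncoderHead
        )
    )

    decoder = nn.Sequential(
        *(
            [nn.Linear(auto_encoder_dim, dim_model)]  # role of DecoderBottom
            + trunk_blocks[nb_blocks // 2 :]
        )
    )

    x = torch.randn(2, 10, dim_model)
    z = encoder(x)      # bottleneck: (2, 10, auto_encoder_dim)
    x_hat = decoder(z)  # reconstructed: (2, 10, dim_model)
    assert x_hat.shape == x.shape

In mygpt.py the same block objects also make up self.trunk, so training the auto-encoder updates the trunk weights as well.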