Update.
author     François Fleuret <francois@fleuret.org>
           Thu, 22 Aug 2024 18:44:02 +0000 (20:44 +0200)
committer  François Fleuret <francois@fleuret.org>
           Thu, 22 Aug 2024 18:44:02 +0000 (20:44 +0200)
main.py
mygpt.py
quiz_machine.py

diff --git a/main.py b/main.py
index fc480b7..36b369e 100755
--- a/main.py
+++ b/main.py
@@ -341,7 +341,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
         nb_test_samples, acc_test_loss = 0, 0.0
         nb_samples_accumulated = 0
 
-        full_input, full_mask_loss = quiz_machine.data_input(
+        full_input, _, full_mask_loss = quiz_machine.data_input(
             args.nb_test_samples, model.test_c_quiz_bags, args.c_quiz_multiplier
         )
         src = zip(
@@ -370,7 +370,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
 
         log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
 
-        input, _ = quiz_machine.data_input(
+        input, _, _ = quiz_machine.data_input(
             2000, model.test_c_quiz_bags, args.c_quiz_multiplier
         )
 
@@ -394,7 +394,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    full_input, full_mask_loss = quiz_machine.data_input(
+    full_input, _, full_mask_loss = quiz_machine.data_input(
         args.nb_train_samples,
         model.train_c_quiz_bags + common_c_quiz_bags,
         args.c_quiz_multiplier,
@@ -635,7 +635,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
 from mygpt import (
     WithResidual,
     CacheWrapper,
-    AddPositionalEncoding,
+    VaswaniPositionalEncoding,
+    TrainablePositionalEncoding,
     QKVAttention,
     BracketedSequence,
 )
@@ -660,7 +661,7 @@ class Thinker(nn.Module):
 
         self.embedding = nn.Sequential(
             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
-            AddPositionalEncoding(len_max),
+            VaswaniPositionalEncoding(len_max),
         )
 
         def trunk(depth):
@@ -743,7 +744,7 @@ class Thinker(nn.Module):
 from mygpt import (
     WithResidual,
     CacheWrapper,
-    AddPositionalEncoding,
+    VaswaniPositionalEncoding,
     QKVAttention,
     BracketedSequence,
 )
@@ -759,7 +760,7 @@ class MyAttentionVAE(nn.Module):
         nb_heads,
         nb_blocks,
         dropout=0.0,
-        len_max=1e5,
+        len_max=1024,
     ):
         super().__init__()
 
@@ -769,7 +770,7 @@ class MyAttentionVAE(nn.Module):
             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
         )
 
-        self.positional_encoding = AddPositionalEncoding(len_max)
+        self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
 
         trunk_blocks = []
 
@@ -850,7 +851,7 @@ def test_ae(local_device=main_device):
             (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
         ]
 
-        full_input, full_mask_loss = quiz_machine.data_input(
+        full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
             args.nb_train_samples, data_structures=data_structures
         )
 
@@ -871,7 +872,7 @@ def test_ae(local_device=main_device):
                 model.optimizer.zero_grad()
 
             targets = input
-            input = (mask_loss == 0).long() * input
+            input = (mask_generate == 0).long() * input
 
             output = model(mygpt.BracketedSequence(input)).x
             loss = F.cross_entropy(output.transpose(1, 2), targets)
@@ -894,7 +895,9 @@ def test_ae(local_device=main_device):
 
             nb_test_samples, acc_test_loss = 0, 0.0
 
-            full_input, full_mask_loss = quiz_machine.data_input(args.nb_test_samples)
+            full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
+                args.nb_test_samples, data_structures=data_structures
+            )
 
             src = zip(
                 full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
@@ -909,7 +912,7 @@ def test_ae(local_device=main_device):
                 input = input.to(local_device)
                 mask_loss = mask_loss.to(local_device)
                 targets = input
-                input = (mask_loss == 0).long() * input
+                input = (mask_generate == 0).long() * input
                 output = model(mygpt.BracketedSequence(input)).x
                 loss = F.cross_entropy(output.transpose(1, 2), targets)
                 acc_test_loss += loss.item() * input.size(0)
@@ -917,11 +920,13 @@ def test_ae(local_device=main_device):
 
             log_string(f"test_loss {n_epoch} model AE {acc_test_loss/nb_test_samples}")
 
-            input, mask_loss = quiz_machine.data_input(128)
+            input, mask_generate, mask_loss = quiz_machine.data_input(
+                128, data_structures=data_structures
+            )
             input = input.to(local_device)
             mask_loss = mask_loss.to(local_device)
             targets = input
-            input = (mask_loss == 0).long() * input
+            input = (mask_generate == 0).long() * input
             logits = model(mygpt.BracketedSequence(input)).x
             dist = torch.distributions.categorical.Categorical(logits=logits)
             result = dist.sample()
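The test_ae path above now unpacks three tensors from quiz_machine.data_input and blanks the input with the generation mask rather than the loss mask. A minimal sketch of how that convention is consumed, assembled from the hunks above (the data_structures entry and the model, quiz_machine and args names are taken from main.py; this is an illustration, not the full training loop):

    # Each data_structures entry is (struct, quad_generate, quad_noise, quad_loss),
    # one flag per quadrant of the ("A", "f_A", "B", "f_B") quiz.
    data_structures = [
        (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
    ]

    input, mask_generate, mask_loss = quiz_machine.data_input(
        args.nb_train_samples, data_structures=data_structures
    )

    targets = input
    # Blank the quadrant to be generated (here f_B) and train the model to
    # reconstruct the complete quiz, denoising-autoencoder style.
    input = (mask_generate == 0).long() * input
    output = model(mygpt.BracketedSequence(input)).x
    loss = F.cross_entropy(output.transpose(1, 2), targets)
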
diff --git a/mygpt.py b/mygpt.py
index f716fe5..8379a57 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -122,7 +122,7 @@ class WithResidual(nn.Module):
 ##############################
 
 
-class AddPositionalEncoding(nn.Module):
+class VaswaniPositionalEncoding(nn.Module):
     def __init__(self, len_max):
         super().__init__()
         self.len_max = len_max
@@ -153,6 +153,26 @@ class AddPositionalEncoding(nn.Module):
 ##############################
 
 
+class TrainablePositionalEncoding(nn.Module):
+    def __init__(self, dim, len_max):
+        super().__init__()
+        self.len_max = len_max
+        self.pe = nn.Parameter(torch.randn(1, len_max, dim) / math.sqrt(dim))
+
+    def forward(self, bs):
+        if bs.first == 0:
+            self.cache_y = bs.x.new(bs.x.size())
+
+        self.cache_y[:, bs.first : bs.first + bs.nb] = (
+            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
+        )
+
+        return BracketedSequence(self.cache_y, bs.first, bs.nb)
+
+
+##############################
+
+
 class EncoderHead(nn.Module):
     def __init__(self, dim_in, dim_out):
         super().__init__()
@@ -338,7 +358,7 @@ class MyGPT(nn.Module):
             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
         )
 
-        self.positional_encoding = AddPositionalEncoding(len_max)
+        self.positional_encoding = VaswaniPositionalEncoding(len_max)
 
         trunk_blocks = []
 
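The renamed VaswaniPositionalEncoding keeps the behaviour of the old AddPositionalEncoding (presumably the fixed sine/cosine scheme of Vaswani et al., given the name), while the new TrainablePositionalEncoding stores a learned (1, len_max, dim) parameter and adds the slice matching the bracketed positions. A minimal usage sketch, with made-up tensor sizes and assuming the single-argument BracketedSequence constructor used in main.py brackets the whole sequence:

    import torch
    from mygpt import BracketedSequence, TrainablePositionalEncoding

    N, T, D = 2, 1024, 64          # hypothetical batch size, length, model dim
    x = torch.randn(N, T, D)       # token embeddings
    bs = BracketedSequence(x)      # assumed to cover positions [0, T)

    pe = TrainablePositionalEncoding(D, T)   # (dim, len_max), as in MyAttentionVAE
    y = pe(bs)                     # BracketedSequence; y.x has shape (N, T, D)
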
diff --git a/quiz_machine.py b/quiz_machine.py
index 08f121a..bea0d78 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -178,23 +178,24 @@ class QuizMachine:
             quizzes, structs=[s for s, _, _, _ in data_structures]
         )
 
+        quiz_mask_generate = quizzes.new_full(quizzes.size(), 1)
         quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
 
-        for struct, _, quad_noise, quad_loss in data_structures:
+        for struct, quad_generate, quad_noise, quad_loss in data_structures:
             i = self.problem.indices_select(quizzes=quizzes, struct=struct)
             if i.any():
                 if self.prompt_noise > 0.0:
                     quizzes[i] = self.problem.inject_noise(
                         quizzes[i], self.prompt_noise, struct=struct, quad=quad_noise
                     )
+                quiz_mask_generate[i] = self.make_quiz_mask(
+                    quizzes=quizzes[i], struct=struct, quad=quad_generate
+                )
                 quiz_mask_loss[i] = self.make_quiz_mask(
                     quizzes=quizzes[i], struct=struct, quad=quad_loss
                 )
 
-        print("quad_loss", quad_loss)
-        print("quiz_mask_loss", quiz_mask_loss)
-
-        return quizzes, quiz_mask_loss
+        return quizzes, quiz_mask_generate, quiz_mask_loss
 
     ######################################################################
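For call sites, the net effect is that QuizMachine.data_input now returns three tensors instead of two. A sketch of the updated contract (the sample count is arbitrary; the shapes are an assumption, both masks being created with new_full(quizzes.size(), 1) in the hunk above):

    quizzes, mask_generate, mask_loss = quiz_machine.data_input(
        nb_samples, data_structures=data_structures
    )

    # All three tensors are assumed to share the (nb_quizzes, seq_len) shape.
    assert quizzes.shape == mask_generate.shape == mask_loss.shape

    # mask_generate flags the tokens selected by quad_generate (to be produced
    # by the model), mask_loss those selected by quad_loss (where the loss is
    # computed); prompt noise, when enabled, goes on the quad_noise quadrants.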