Update: train and test with tokens presented in a per-sequence random order sigma, extending AddPositionalEncoding and MyGPT.forward to take the permutation.
author: François Fleuret <francois@fleuret.org>
Sun, 4 Aug 2024 08:36:25 +0000 (10:36 +0200)
committer: François Fleuret <francois@fleuret.org>
Sun, 4 Aug 2024 08:36:25 +0000 (10:36 +0200)
main.py
mygpt.py

diff --git a/main.py b/main.py
index 8f3568f..63f6cce 100755
--- a/main.py
+++ b/main.py
@@ -372,7 +372,8 @@ def run_tests(model, quiz_machine, local_device=main_device):
 
         for input in tqdm.tqdm(src, dynamic_ncols=True, desc="test"):
             input = input.to(local_device)
-            output = model(mygpt.BracketedSequence(input)).x
+            sigma = torch.rand(input.size(), device=input.device).sort(dim=1).indices
+            output = model(mygpt.BracketedSequence(input), sigma).x
             loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_test_loss += loss.item() * input.size(0)
             nb_test_samples += input.size(0)
@@ -417,7 +418,8 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
         targets = input
 
-        output = model(mygpt.BracketedSequence(input)).x
+        sigma = torch.rand(input.size(), device=input.device).sort(dim=1).indices
+        output = model(mygpt.BracketedSequence(input), sigma).x
         loss_per_token = F.cross_entropy(
             output.transpose(1, 2), targets, reduction="none"
         )
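
Note: the sigma drawn in both hunks above is a per-sequence uniform random permutation. Sorting i.i.d. uniform samples along the time dimension returns permutation indices, one independent permutation per row, equivalent to a batched torch.randperm. A minimal sketch of the trick, with hypothetical sizes:

    import torch

    input = torch.randint(10, (2, 5))  # hypothetical (N, T) token batch
    # Sorting i.i.d. uniform values yields an unbiased random permutation
    # of 0..T-1 in each row.
    sigma = torch.rand(input.size(), device=input.device).sort(dim=1).indices
    print(sigma)  # e.g. tensor([[3, 0, 4, 1, 2], [1, 4, 0, 2, 3]])
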
diff --git a/mygpt.py b/mygpt.py
index 15ed80e..b1cdf4d 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -90,22 +90,38 @@ class AddPositionalEncoding(nn.Module):
 
     # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
 
-    def forward(self, bs):
+    def forward(self, bs, sigma=None):
         if bs.first == 0:
-            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
-                :, None
-            ]
-            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
-                None, :
-            ]
-            k = j % 2
-            self.pe = torch.sin(
-                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
-            )
+            if sigma is None:
+                t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
+                    None, :, None
+                ]
+                j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
+                    None, None, :
+                ]
+                k = j % 2
+                self.pe = torch.sin(
+                    t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+                )
+            else:
+                t_out = sigma[:, :, None]
+                t_in = F.pad(t_out, (0, 0, 1, -1), value=-1)
+                j = torch.arange(
+                    bs.x.size(2) // 2, dtype=bs.x.dtype, device=bs.x.device
+                )[None, None, :]
+                k = j % 2
+                pe_out = torch.sin(
+                    t_out / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+                )
+                pe_in = torch.sin(
+                    t_in / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+                )
+                self.pe = torch.cat([pe_in, pe_out], dim=2)
+
             self.cache_y = bs.x.new(bs.x.size())
 
         self.cache_y[:, bs.first : bs.first + bs.nb] = (
-            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
+            bs.slice() + self.pe[:, bs.first : bs.first + bs.nb]
         )
 
         return BracketedSequence(self.cache_y, bs.first, bs.nb)
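
Note: in the sigma branch above, every step carries two positions: t_out, the position whose token is to be predicted at that step, and t_in, the position of the token just read, i.e. t_out shifted right by one via F.pad, with -1 at the first step where nothing has been read yet. Each is encoded with half of the model dimension and the two halves are concatenated, so the cached table becomes per-sample, which is why self.pe gains a leading batch dimension and the slice above picks up an extra ':'. A standalone sketch under assumed sizes N, T, D:

    import math
    import torch
    import torch.nn.functional as F

    N, T, D, len_max = 2, 5, 8, 1024  # hypothetical; D must be even
    sigma = torch.rand(N, T).sort(dim=1).indices.float()

    t_out = sigma[:, :, None]                     # position predicted at step t
    t_in = F.pad(t_out, (0, 0, 1, -1), value=-1)  # position read at step t

    j = torch.arange(D // 2, dtype=t_out.dtype)[None, None, :]
    k = j % 2
    pe_out = torch.sin(t_out / (len_max ** ((j - k) / D)) + math.pi / 2 * k)
    pe_in = torch.sin(t_in / (len_max ** ((j - k) / D)) + math.pi / 2 * k)
    pe = torch.cat([pe_in, pe_out], dim=2)        # (N, T, D), one table per sample
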
@@ -262,11 +278,12 @@ class MyGPT(nn.Module):
 
         self.temperature = 1.0
 
-        self.embedding = nn.Sequential(
-            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
-            AddPositionalEncoding(len_max),
+        self.embedding = CacheWrapper(
+            nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)
         )
 
+        self.positional_encoding = AddPositionalEncoding(len_max)
+
         trunk_blocks = []
 
         for b in range(nb_blocks):
@@ -331,12 +348,19 @@ class MyGPT(nn.Module):
                     m.bias.zero_()
                     m.weight.fill_(1.0)
 
-    def forward(self, bs):
+    def forward(self, bs, sigma=None):
+        if sigma is not None:
+            bs.x = bs.x.gather(dim=1, index=sigma)
         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
         bs = self.embedding(bs)
+        bs = self.positional_encoding(bs, sigma)
         bs = self.trunk(bs)
         bs = self.readout(bs)
         bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature
+        if sigma is not None:
+            bs.x.scatter_(
+                dim=1, index=sigma[:, :, None].expand_as(bs.x), src=bs.x.clone()
+            )
         return bs
 
     def encode(self, bs):
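
Note: the gather/scatter pair above runs the trunk on tokens reordered by sigma, then writes the outputs back to their original positions, so callers receive logits aligned with the unpermuted input; src=bs.x.clone() is needed because scatter_ writes in place into the very tensor being read. A round-trip sketch with hypothetical shapes:

    import torch

    N, T, C = 2, 5, 7  # hypothetical batch, length, vocabulary sizes
    x = torch.randint(C, (N, T))
    sigma = torch.rand(N, T).sort(dim=1).indices

    x_perm = x.gather(dim=1, index=sigma)  # x_perm[b, t] == x[b, sigma[b, t]]

    logits = torch.randn(N, T, C)          # stand-in for the trunk/readout output
    logits.scatter_(
        dim=1, index=sigma[:, :, None].expand_as(logits), src=logits.clone()
    )
    # Now logits[b, sigma[b, t]] holds what was computed at permuted step t,
    # i.e. every prediction lands back at the position it was made for.
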
@@ -351,7 +375,6 @@ class MyGPT(nn.Module):
 
     def partial_forward(self, bs, start_layer=None, end_layer=None):
         if start_layer is None:
-            # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
             bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
             bs = self.embedding(bs)
             if end_layer is not None: