Update.
[picoclvr.git] / mygpt.py
index b4446c6..6a12a5a 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -5,6 +5,11 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
+# This is an implementation from scratch of a "GPT", that is a model
+# composed of several causal self-attention blocks. It is equipped
+# with a caching mechanism for keys and values to avoid a O(N^3) cost
+# for auto-regression.
+
 import math
 
 import torch
@@ -14,19 +19,6 @@ from torch.nn import functional as F
 
 ######################################################################
 
-
-class WithResidual(nn.Module):
-    def __init__(self, *f):
-        super().__init__()
-        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
-    def forward(self, bs):
-        bs.x = bs.x + self.f(bs).x
-        return bs
-
-
-######################################################################
-
 # A BracketedSequence is a BxTx... tensor with a first and a nb time
 # steps to compute.
 
@@ -78,6 +70,19 @@ class CacheWrapper(nn.Module):
 ##############################
 
 
+class WithResidual(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, bs):
+        bs.x = bs.x + self.f(bs).x
+        return bs
+
+
+##############################
+
+
 class AddPositionalEncoding(nn.Module):
     def __init__(self, len_max):
         super().__init__()