Update.
author    François Fleuret <francois@fleuret.org>
          Sun, 8 Sep 2024 09:59:58 +0000 (11:59 +0200)
committer François Fleuret <francois@fleuret.org>
          Sun, 8 Sep 2024 09:59:58 +0000 (11:59 +0200)
attae.py
main.py

diff --git a/attae.py b/attae.py
index 3a9f105..7bd4a44 100755 (executable)
--- a/attae.py
+++ b/attae.py
@@ -102,7 +102,7 @@ class AttentionAE(nn.Module):
         assert dim_model % nb_heads == 0
 
         self.embedding = nn.Sequential(
-            nn.Embedding(2 * vocabulary_size, dim_model),
+            nn.Embedding(vocabulary_size, dim_model),
             nn.Dropout(dropout),
         )
 
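
The hunk above shrinks the embedding table from 2 * vocabulary_size rows to vocabulary_size rows. A minimal sketch of the consequence, with illustrative sizes matching the __main__ test below: the table now only accepts token ids in [0, vocabulary_size), so any code that offset ids into the upper half of a doubled table would now go out of range.

    import torch
    import torch.nn as nn

    vocabulary_size, dim_model = 100, 16
    embedding = nn.Sequential(
        nn.Embedding(vocabulary_size, dim_model),  # one row per token id
        nn.Dropout(0.1),
    )

    x = torch.randint(vocabulary_size, (10, 50))  # ids in [0, vocabulary_size)
    print(embedding(x).shape)                     # torch.Size([10, 50, 16])
    # embedding(x + vocabulary_size) would now raise an IndexError
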
@@ -166,5 +166,11 @@ if __name__ == "__main__":
     )
 
     x = torch.randint(100, (10, 50))
-
     y = model(x)
+
+    with torch.no_grad():
+        model.eval()
+        x = torch.randint(100, (10, 50))
+        y = model(x)
+
+        print(y)
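
The extended __main__ block wraps the sanity check in model.eval() and torch.no_grad(). A small self-contained illustration of why that matters for a module containing dropout (toy module below, not the actual AttentionAE): eval() turns dropout into the identity so the forward pass is deterministic, and no_grad() avoids recording an autograd graph for an output that is only printed.

    import torch
    import torch.nn as nn

    m = nn.Sequential(nn.Embedding(100, 16), nn.Dropout(0.5))
    x = torch.randint(100, (10, 50))

    m.eval()                      # dropout becomes the identity
    with torch.no_grad():         # no autograd graph is recorded
        y = m(x)

    print(y.requires_grad)        # False
    print(torch.equal(y, m(x)))   # True: eval mode is deterministic
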
diff --git a/main.py b/main.py
index a4030ff..d90a3df 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -16,6 +16,8 @@ from torch.nn import functional as F
 
 import ffutils
 
+import attae
+
 import mygpt
 import sky, grids, quiz_machine
 
@@ -373,7 +375,7 @@ def optimizer_to(optim, device):
 
 
 from mygpt import (
-    WithResidual,
+    CachedWithResidual,
     CacheWrapper,
     CachedVaswaniPositionalEncoding,
     QKVAttention,
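
This hunk and the two that follow swap WithResidual for CachedWithResidual. A hedged sketch of the residual-wrapper pattern the two names refer to; the actual mygpt classes also carry the record/use-cache machinery used for incremental inference, which this toy version omits and which is presumably the point of the swap.

    import torch.nn as nn

    class ResidualSketch(nn.Module):
        # Wraps one or more sub-modules and adds their output to the input,
        # so the wrapped branch only has to learn a correction.
        def __init__(self, *modules):
            super().__init__()
            self.f = nn.Sequential(*modules)

        def forward(self, x):
            return x + self.f(x)
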
@@ -394,7 +396,7 @@ class MultiEmbedding(nn.Module):
 
 
 def attention_block(dim_model, dim_keys, nb_heads, dropout):
-    return WithResidual(
+    return CachedWithResidual(
         CacheWrapper(
             nn.LayerNorm((dim_model,)),
         ),
@@ -409,7 +411,7 @@ def attention_block(dim_model, dim_keys, nb_heads, dropout):
 
 
 def ffw_block(dim_model, dim_hidden, nb_heads, dropout):
-    return WithResidual(
+    return CachedWithResidual(
         CacheWrapper(
             nn.LayerNorm((dim_model,)),
             nn.Linear(in_features=dim_model, out_features=dim_hidden),
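
attention_block and ffw_block each return one residual block of a pre-norm transformer layer. A small self-contained check of the shape contract such blocks must satisfy, namely that the branch maps (batch, seq, dim_model) back to the same shape so the residual addition is well defined (toy feed-forward branch only; the real attention_block uses QKVAttention in its branch).

    import torch
    import torch.nn as nn

    dim_model, dim_hidden = 64, 256

    branch = nn.Sequential(
        nn.LayerNorm((dim_model,)),
        nn.Linear(dim_model, dim_hidden),
        nn.ReLU(),
        nn.Linear(dim_hidden, dim_model),
    )

    x = torch.randn(10, 50, dim_model)
    y = x + branch(x)   # residual addition requires matching shapes
    print(y.shape)      # torch.Size([10, 50, 64])
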
@@ -438,7 +440,8 @@ class MyAttentionAE(nn.Module):
 
         self.embedding = CacheWrapper(
             nn.Sequential(
-                MultiEmbedding((vocabulary_size, 2), dim_model), nn.Dropout(dropout)
+                MultiEmbedding((vocabulary_size, 2), dim_model),
+                nn.Dropout(dropout),
             ),
         )
 
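
The embedding hunk above only reflows the call over two lines. The MultiEmbedding class itself is not shown in this diff; a hedged guess at the pattern the call MultiEmbedding((vocabulary_size, 2), dim_model) suggests is one embedding table per input channel, with the per-channel vectors summed (the real class in main.py may differ).

    import torch
    import torch.nn as nn

    class MultiEmbeddingSketch(nn.Module):
        # one embedding table per input channel, outputs summed
        def __init__(self, sizes, dim_model):
            super().__init__()
            self.embeddings = nn.ModuleList(
                nn.Embedding(n, dim_model) for n in sizes
            )

        def forward(self, x):
            # x has shape (batch, seq, nb_channels)
            return sum(e(x[..., i]) for i, e in enumerate(self.embeddings))

    e = MultiEmbeddingSketch((100, 2), 16)
    x = torch.stack(
        (torch.randint(100, (10, 50)), torch.randint(2, (10, 50))), dim=-1
    )
    print(e(x).shape)  # torch.Size([10, 50, 16])
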
@@ -997,7 +1000,7 @@ models = []
 
 for i in range(args.nb_models):
     model = MyAttentionAE(
-        # model = FunctionalAE(
+        # model = attae.AttentionAE(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
         dim_keys=args.dim_keys,