Update: add a positional encoding to the Reasoning module, reduce nb_f_tokens from 250 to 25, and attach a per-model Adam optimizer in main.py.
author     François Fleuret <francois@fleuret.org>
Wed, 9 Oct 2024 11:16:48 +0000 (13:16 +0200)
committer  François Fleuret <francois@fleuret.org>
Wed, 9 Oct 2024 11:16:48 +0000 (13:16 +0200)
attae.py
main.py

diff --git a/attae.py b/attae.py
index c386e3a..d971b9d 100755 (executable)
--- a/attae.py
+++ b/attae.py
@@ -298,6 +298,7 @@ class Reasoning(nn.Module):
         nb_blocks=1,
         attention=vanilla_attention,
         attention_dropout=0.0,
+        len_max=1e5,
     ):
         super().__init__()
 
@@ -307,6 +308,8 @@ class Reasoning(nn.Module):
         self.nb_chunks = nb_chunks
         self.x_star = randw(nb_f_tokens, dim_model)
 
+        self.positional_encoding = VaswaniPositionalEncoding(len_max)
+
         self.trunk_joint = create_trunk(
             dim_model=dim_model,
             dim_keys=dim_qk,
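
The new len_max argument is used to build a VaswaniPositionalEncoding that the Reasoning module now applies to its input. For reference, below is a minimal sketch of such a sinusoidal encoding in the style of Vaswani et al. (2017), with len_max acting as the frequency base; the actual implementation in attae.py may differ in its details.

    import math

    import torch
    from torch import nn


    class VaswaniPositionalEncoding(nn.Module):
        # Hypothetical sketch of the sinusoidal encoding of Vaswani et al.
        # (2017); the real class in attae.py may be organized differently.
        def __init__(self, len_max):
            super().__init__()
            self.len_max = len_max

        def forward(self, x):
            # x is of shape (batch, time, dim); add a position-dependent
            # signal made of sines and cosines at geometrically spaced
            # frequencies, with len_max as the base of the geometric scale.
            t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
            j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
            k = j % 2
            pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
            return x + pe
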
diff --git a/main.py b/main.py
index 41862a2..0bb3212 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -932,7 +932,7 @@ if args.test == "aebn":
     model = new_model(0)
     # f = model.trunk[len(model.trunk) // 2 :]
     model.trunk = attae.Reasoning(
-        nb_f_tokens=250,
+        nb_f_tokens=25,
         nb_chunks=2,
         dim_model=args.dim_model,
         dim_qk=args.dim_keys,
@@ -951,6 +951,8 @@ if args.test == "aebn":
     # model.test_accuracy = 0.0
     # model.nb_epochs = 0
 
+    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
     for n_epoch in range(args.nb_epochs):
         one_complete_epoch(
             model,
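
The added line in main.py attaches an Adam optimizer directly to the model before the epoch loop, so the model carries its own optimizer state across epochs. Below is a minimal sketch of how such a per-model optimizer is typically consumed inside a training step, assuming one_complete_epoch runs a standard zero_grad / backward / step cycle; train_step, batch and loss_fn are hypothetical names, and the real loop in main.py may differ.

    import torch


    def train_step(model, batch, loss_fn):
        # Standard PyTorch update using the optimizer attached to the model,
        # so each model keeps its own Adam state rather than sharing one.
        model.optimizer.zero_grad()
        loss = loss_fn(model, batch)
        loss.backward()
        model.optimizer.step()
        return loss.item()
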