From 61f3811cad42aa21ee1fa034144b257abc85ce8c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 9 Oct 2024 13:16:48 +0200 Subject: [PATCH] Update. --- attae.py | 3 +++ main.py | 4 +++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/attae.py b/attae.py index c386e3a..d971b9d 100755 --- a/attae.py +++ b/attae.py @@ -298,6 +298,7 @@ class Reasoning(nn.Module): nb_blocks=1, attention=vanilla_attention, attention_dropout=0.0, + len_max=1e5, ): super().__init__() @@ -307,6 +308,8 @@ class Reasoning(nn.Module): self.nb_chunks = nb_chunks self.x_star = randw(nb_f_tokens, dim_model) + self.positional_encoding = VaswaniPositionalEncoding(len_max) + self.trunk_joint = create_trunk( dim_model=dim_model, dim_keys=dim_qk, diff --git a/main.py b/main.py index 41862a2..0bb3212 100755 --- a/main.py +++ b/main.py @@ -932,7 +932,7 @@ if args.test == "aebn": model = new_model(0) # f = model.trunk[len(model.trunk) // 2 :] model.trunk = attae.Reasoning( - nb_f_tokens=250, + nb_f_tokens=25, nb_chunks=2, dim_model=args.dim_model, dim_qk=args.dim_keys, @@ -951,6 +951,8 @@ if args.test == "aebn": # model.test_accuracy = 0.0 # model.nb_epochs = 0 + model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + for n_epoch in range(args.nb_epochs): one_complete_epoch( model, -- 2.39.5