From: François Fleuret Date: Wed, 9 Oct 2024 11:16:48 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=61f3811cad42aa21ee1fa034144b257abc85ce8c;p=culture.git Update. --- diff --git a/attae.py b/attae.py index c386e3a..d971b9d 100755 --- a/attae.py +++ b/attae.py @@ -298,6 +298,7 @@ class Reasoning(nn.Module): nb_blocks=1, attention=vanilla_attention, attention_dropout=0.0, + len_max=1e5, ): super().__init__() @@ -307,6 +308,8 @@ class Reasoning(nn.Module): self.nb_chunks = nb_chunks self.x_star = randw(nb_f_tokens, dim_model) + self.positional_encoding = VaswaniPositionalEncoding(len_max) + self.trunk_joint = create_trunk( dim_model=dim_model, dim_keys=dim_qk, diff --git a/main.py b/main.py index 41862a2..0bb3212 100755 --- a/main.py +++ b/main.py @@ -932,7 +932,7 @@ if args.test == "aebn": model = new_model(0) # f = model.trunk[len(model.trunk) // 2 :] model.trunk = attae.Reasoning( - nb_f_tokens=250, + nb_f_tokens=25, nb_chunks=2, dim_model=args.dim_model, dim_qk=args.dim_keys, @@ -951,6 +951,8 @@ if args.test == "aebn": # model.test_accuracy = 0.0 # model.nb_epochs = 0 + model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + for n_epoch in range(args.nb_epochs): one_complete_epoch( model,