nb_blocks=1,
attention=vanilla_attention,
attention_dropout=0.0,
+ len_max=1e5,
):
super().__init__()
self.nb_chunks = nb_chunks
self.x_star = randw(nb_f_tokens, dim_model)
+ self.positional_encoding = VaswaniPositionalEncoding(len_max)
+
self.trunk_joint = create_trunk(
dim_model=dim_model,
dim_keys=dim_qk,
model = new_model(0)
# f = model.trunk[len(model.trunk) // 2 :]
model.trunk = attae.Reasoning(
- nb_f_tokens=250,
+ nb_f_tokens=25,
nb_chunks=2,
dim_model=args.dim_model,
dim_qk=args.dim_keys,
# model.test_accuracy = 0.0
# model.nb_epochs = 0
+ model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
for n_epoch in range(args.nb_epochs):
one_complete_epoch(
model,