From d12981e75482b80a73f809259ea62150754dd53f Mon Sep 17 00:00:00 2001
From: François Fleuret
Date: Sun, 8 Sep 2024 12:31:09 +0200
Subject: [PATCH] Update.

---
 attae.py | 5 +++--
 main.py  | 7 +++++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/attae.py b/attae.py
index 7bd4a44..e9e4bff 100755
--- a/attae.py
+++ b/attae.py
@@ -102,7 +102,7 @@ class AttentionAE(nn.Module):
         assert dim_model % nb_heads == 0
 
         self.embedding = nn.Sequential(
-            nn.Embedding(vocabulary_size, dim_model),
+            nn.Embedding(2 * vocabulary_size, dim_model),
             nn.Dropout(dropout),
         )
 
@@ -143,7 +143,8 @@ class AttentionAE(nn.Module):
                 m.bias.zero_()
                 m.weight.fill_(1.0)
 
-    def forward(self, x, mask=None):
+    def forward(self, x):
+        x = 2 * x[:, :, 0] + x[:, :, 1]
         x = self.embedding(x)
         x = self.positional_encoding(x)
         x = self.trunk(x)
diff --git a/main.py b/main.py
index d90a3df..9285337 100755
--- a/main.py
+++ b/main.py
@@ -999,8 +999,8 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
 models = []
 
 for i in range(args.nb_models):
-    model = MyAttentionAE(
-    # model = attae.AttentionAE(
+    # model = MyAttentionAE(
+    model = attae.AttentionAE(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
         dim_keys=args.dim_keys,
@@ -1338,6 +1338,9 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     else:
         log_string(f"nb_c_quizzes {c_quizzes.size(0)}")
 
+    # one_ae_epoch(model, quiz_machine, n_epoch, None)
+    # exit(0)
+
     # --------------------------------------------------------------------
 
     ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
-- 
2.39.5
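
Note on the input convention introduced by the attae.py hunks: forward() no
longer takes a mask keyword and instead expects an integer tensor of shape
(batch, seq, 2), fusing the two channels into a single token id via
2 * x[:, :, 0] + x[:, :, 1]. Since the embedding table is now sized
2 * vocabulary_size, the fused id only stays in range if channel 0 holds
token ids below vocabulary_size and channel 1 is a binary flag in {0, 1};
plausibly the old mask argument has been folded into that second channel.
The sketch below illustrates this under those assumptions; the tensor names
and sizes are invented for the example, not taken from the repository.

    import torch
    import torch.nn as nn

    vocabulary_size = 64
    batch, seq = 2, 10

    # Channel 0: token ids in [0, vocabulary_size).
    tokens = torch.randint(vocabulary_size, (batch, seq))
    # Channel 1: assumed binary per-token flag in {0, 1}.
    flags = torch.randint(2, (batch, seq))

    x = torch.stack([tokens, flags], dim=2)  # shape (batch, seq, 2)

    # What the patched forward() computes before embedding:
    fused = 2 * x[:, :, 0] + x[:, :, 1]      # ids in [0, 2 * vocabulary_size)

    embedding = nn.Embedding(2 * vocabulary_size, 8)
    assert fused.max().item() < embedding.num_embeddings
    e = embedding(fused)                     # shape (batch, seq, 8)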