From bde1469992bbbae012b52373f29409317af3492c Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Wed, 9 Oct 2024 17:42:51 +0200
Subject: [PATCH] Update.

---
 attae.py | 18 +++++++++---------
 main.py  | 19 +++++++++++++++++--
 2 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/attae.py b/attae.py
index d0f4e6b..5d9da2e 100755
--- a/attae.py
+++ b/attae.py
@@ -24,7 +24,8 @@ class VaswaniPositionalEncoding(nn.Module):
         t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
         j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
         k = j % 2  # works with float, weird
-        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
+        period = self.len_max ** ((j - k) / x.size(2))
+        pe = torch.sin(t / period + (math.pi / 2) * k)
         y = x + pe
         return y
 
@@ -330,27 +331,26 @@ class Reasoning(nn.Module):
 
     def forward(self, x_q):
         T, S = x_q.size(1), self.x_star.size(0)
-        nb, dim = x_q.size(0), x_q.size(2)
-        x_star = self.x_star.unsqueeze(0).expand(nb, -1, -1)
+        nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks
+
+        x_star = self.x_star.reshape(1, S, dim).expand(nb, S, dim)
 
         x = torch.cat([x_star, x_q], dim=1)
         x = self.trunk_joint(x)
-        f, x = x[:, : x_star.size(1), :], x[:, x_star.size(1) :, :]
+        f, x = x[:, :S, :], x[:, S:, :]
 
         if hasattr(self, "forced_f") and self.forced_f is not None:
             f = self.forced_f.clone()
 
         self.pred_f = f.clone()
 
-        x = x.reshape(nb, self.nb_chunks, T // self.nb_chunks, dim)
-        f = f.unsqueeze(1).expand(nb, self.nb_chunks, S, dim)
-        x = x.reshape(nb * self.nb_chunks, T // self.nb_chunks, dim)
-        f = f.reshape(nb * self.nb_chunks, S, dim)
+        x = x.reshape(nb * nc, T // nc, dim)
+        f = f.repeat(nc, 1, 1)
 
         x = torch.cat([f, x], dim=1)
         x = self.trunk_marginal(x)
 
-        x = x[:, x_star.size(1) :, :]
+        x = x[:, S:, :]
 
         x = x.reshape(nb, T, dim)
         return x
diff --git a/main.py b/main.py
index 0bb3212..618a62e 100755
--- a/main.py
+++ b/main.py
@@ -929,8 +929,23 @@ log_string(f"vocabulary_size {vocabulary_size}")
 ######################################################################
 
 if args.test == "aebn":
-    model = new_model(0)
-    # f = model.trunk[len(model.trunk) // 2 :]
+    model = attae.AttentionAE(
+        vocabulary_size_in=vocabulary_size * 2,
+        vocabulary_size_out=vocabulary_size,
+        dim_model=args.dim_model,
+        dim_keys=args.dim_keys,
+        dim_hidden=args.dim_hidden,
+        nb_heads=args.nb_heads,
+        nb_blocks=args.nb_blocks,
+        dropout=args.dropout,
+        len_max=1e4,
+    )
+
+    model.id = 0
+    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    model.test_accuracy = 0.0
+    model.nb_epochs = 0
+
    model.trunk = attae.Reasoning(
        nb_f_tokens=25,
        nb_chunks=2,
-- 
2.39.5
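
The first hunk only factors the sinusoidal encoding into a separate period term; the value is unchanged. Below is a minimal, self-contained sketch of what those two lines compute, standalone rather than as a module method; the function name, signature, and the len_max default are illustrative, not taken from the repository. The pi/2 phase shift turns the sine into a cosine on odd dimensions, which reproduces the usual sin/cos interleaving of the Vaswani et al. encoding.

import math
import torch

def sinusoidal_pe(seq_len, dim, len_max=1e4, dtype=torch.float32):
    t = torch.arange(seq_len, dtype=dtype)[:, None]  # positions, shape (seq_len, 1)
    j = torch.arange(dim, dtype=dtype)[None, :]      # feature indices, shape (1, dim)
    k = j % 2                                        # 0 on even dims, 1 on odd dims
    period = len_max ** ((j - k) / dim)              # wavelength shared by each sin/cos pair
    # sin(x + pi/2) == cos(x), so odd dims become cosines
    return torch.sin(t / period + (math.pi / 2) * k)

pe = sinusoidal_pe(seq_len=6, dim=8)
print(pe.shape)  # torch.Size([6, 8])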
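
For the Reasoning.forward hunk, here is a toy, shapes-only sketch of the chunking the rewritten code performs before calling trunk_marginal: each sequence is split into nc chunks along the time axis and the f tokens are tiled along the batch axis so that every chunk row receives an S-token prefix. The values of nb, nc, T, S, dim are illustrative and trunk_joint/trunk_marginal are omitted.

import torch

nb, nc, T, S, dim = 2, 2, 8, 3, 4
x = torch.randn(nb, T, dim)           # per-sequence token embeddings
f = torch.randn(nb, S, dim)           # predicted f tokens, one block per sequence

x = x.reshape(nb * nc, T // nc, dim)  # split each sequence into nc chunks
f = f.repeat(nc, 1, 1)                # tile the f blocks along the enlarged batch
x = torch.cat([f, x], dim=1)          # each row: S f-tokens followed by T // nc chunk tokens
print(x.shape)                        # torch.Size([4, 7, 4]) == (nb * nc, S + T // nc, dim)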