t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
k = j % 2 # 1 on odd feature indices, 0 on even (the modulo works on the float tensor)
- pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
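+ # classic interleaved sinusoidal encoding: sin(t / period) on even dims, cos(t / period) on odd dims via the pi/2 phase shift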
+ period = self.len_max ** ((j - k) / x.size(2))
+ pe = torch.sin(t / period + (math.pi / 2) * k)
y = x + pe
return y
def forward(self, x_q):
T, S = x_q.size(1), self.x_star.size(0)
- nb, dim = x_q.size(0), x_q.size(2)
- x_star = self.x_star.unsqueeze(0).expand(nb, -1, -1)
+ nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks
+
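+ # broadcast the shared x_star tokens over the batch (expand is a view, no copy)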
+ x_star = self.x_star.reshape(1, S, dim).expand(nb, S, dim)
x = torch.cat([x_star, x_q], dim=1)
x = self.trunk_joint(x)
- f, x = x[:, : x_star.size(1), :], x[:, x_star.size(1) :, :]
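+ # the first S positions are the predicted f tokens, the remaining T positions are the sequence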
+ f, x = x[:, :S, :], x[:, S:, :]
if hasattr(self, "forced_f") and self.forced_f is not None:
f = self.forced_f.clone()
self.pred_f = f.clone()
- x = x.reshape(nb, self.nb_chunks, T // self.nb_chunks, dim)
- f = f.unsqueeze(1).expand(nb, self.nb_chunks, S, dim)
- x = x.reshape(nb * self.nb_chunks, T // self.nb_chunks, dim)
- f = f.reshape(nb * self.nb_chunks, S, dim)
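+ # cut the T positions into nc chunks (T must be divisible by nc); each chunk gets a copy of its own sample's f tokens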
+ x = x.reshape(nb * nc, T // nc, dim)
+ f = f.repeat_interleave(nc, dim=0)  # interleave so chunk b*nc+c stays paired with f[b], matching the x reshape
x = torch.cat([f, x], dim=1)
x = self.trunk_marginal(x)
- x = x[:, x_star.size(1) :, :]
+ x = x[:, S:, :]
x = x.reshape(nb, T, dim)
return x
######################################################################
if args.test == "aebn":
- model = new_model(0)
- # f = model.trunk[len(model.trunk) // 2 :]
+ model = attae.AttentionAE(
+ vocabulary_size_in=vocabulary_size * 2,
+ vocabulary_size_out=vocabulary_size,
+ dim_model=args.dim_model,
+ dim_keys=args.dim_keys,
+ dim_hidden=args.dim_hidden,
+ nb_heads=args.nb_heads,
+ nb_blocks=args.nb_blocks,
+ dropout=args.dropout,
+ len_max=1e4,
+ )
+
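+ # per-model bookkeeping, presumably consumed by the surrounding training loop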
+ model.id = 0
+ model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+ model.test_accuracy = 0.0
+ model.nb_epochs = 0
+
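+ # replace the model's trunk with the Reasoning module whose forward pass is shown above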
model.trunk = attae.Reasoning(
nb_f_tokens=25,
nb_chunks=2,