class WithResidual(nn.Module):
- def __init__(self, *f):
+ def __init__(self, f, masker=None):
super().__init__()
self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+ self.masker = masker
def forward(self, x):
- return x + self.f(x)
+ if self.masker is None:
+ mask = 1
+ else:
+ mask = self.masker(x)
+ return mask * x + self.f(x)
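# Minimal usage sketch of the masked residual (the masker and all sizes below
# are illustrative, not part of the change): the masker returns a per-position
# factor that scales the skip connection, so positions masked to 0 keep only
# f(x) and lose the identity path.
import torch
from torch import nn

def first_two_no_residual(x):
    m = x.new_ones(1, x.size(1), 1)
    m[:, :2, :] = 0
    return m

blk = WithResidual(
    f=(nn.LayerNorm((8,)), nn.Linear(8, 8)),
    masker=first_two_no_residual,
)

y = blk(torch.randn(3, 5, 8))  # (3, 5, 8); positions 0 and 1 keep only f(x)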
######################################################################
######################################################################
-def create_trunk(dim_model, dim_keys, dim_hidden, nb_heads, nb_blocks, dropout=0.0):
+def create_trunk(
+ dim_model,
+ dim_keys,
+ dim_hidden,
+ nb_heads,
+ nb_blocks,
+ dropout=0.0,
+ residual_masker=None,
+):
trunk_blocks = []
for b in range(nb_blocks):
trunk_blocks += [
WithResidual(
- nn.LayerNorm((dim_model,)),
- MHAttention(
- dim_model=dim_model,
- dim_qk=dim_keys,
- dim_v=dim_model // nb_heads,
- nb_heads=nb_heads,
- attention=vanilla_attention,
- attention_dropout=dropout,
+ masker=residual_masker,
+ f=(
+ nn.LayerNorm((dim_model,)),
+ MHAttention(
+ dim_model=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ attention=vanilla_attention,
+ attention_dropout=dropout,
+ ),
),
),
WithResidual(
- nn.LayerNorm((dim_model,)),
- nn.Linear(in_features=dim_model, out_features=dim_hidden),
- nn.ReLU(),
- nn.Linear(in_features=dim_hidden, out_features=dim_model),
- nn.Dropout(dropout),
+ masker=residual_masker,
+ f=(
+ nn.LayerNorm((dim_model,)),
+ nn.Linear(in_features=dim_model, out_features=dim_hidden),
+ nn.ReLU(),
+ nn.Linear(in_features=dim_hidden, out_features=dim_model),
+ nn.Dropout(dropout),
+ ),
),
]
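# Shape-check sketch for the modified create_trunk, assuming (as not shown in
# this hunk) that it ends with `return nn.Sequential(*trunk_blocks)` and that
# MHAttention / vanilla_attention are the ones defined earlier in this file;
# the masker and the sizes are illustrative only.
import torch

def zero_first_four_residual(x):
    m = x.new_ones(1, x.size(1), 1)
    m[:, :4, :] = 0
    return m

trunk = create_trunk(
    dim_model=64,
    dim_keys=16,
    dim_hidden=128,
    nb_heads=4,
    nb_blocks=2,
    dropout=0.1,
    residual_masker=zero_first_four_residual,
)

y = trunk(torch.randn(2, 10, 64))
assert y.size() == (2, 10, 64)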
def forward(self, x):
x = self.embedding(x)
- warnings.warn("flipping order for symmetry check", RuntimeWarning)
- x = torch.cat([x[:, 200:], x[:, :200]], dim=1)
+ # warnings.warn("flipping order for symmetry check", RuntimeWarning)
+
x = self.positional_encoding(x)
- x = torch.cat([x[:, 200:], x[:, :200]], dim=1)
x = self.trunk(x)
+
x = self.readout(x)
return x
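# Sketch of the kind of symmetry check the commented-out lines above were used
# for (the helper names are hypothetical; the split point of 200 is taken from
# those lines): swap the two halves of the sequence and compare the output on
# the swapped input with the swapped output on the original input.
import torch

def half_swap(t, split=200):
    return torch.cat([t[:, split:], t[:, :split]], dim=1)

def half_swap_gap(model, x, split=200):
    with torch.no_grad():
        y, y_swapped = model(x), model(half_swap(x, split))
    return (half_swap(y, split) - y_swapped).abs().max().item()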
######################################################################
-class WithMaskedResidual(nn.Module):
- def __init__(self, masker, *f):
- super().__init__()
- self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
- self.masker = masker
- self.mask = None
-
- def forward(self, x):
- if self.mask is None:
- self.mask = self.masker(x)
- return self.mask * x + self.f(x)
-
-
-######################################################################
-
-
class FunctionalAttentionAE(nn.Module):
def __init__(
self,
attention=vanilla_attention,
attention_dropout=0.0,
len_max=1e5,
+ residual_masker=None,
):
super().__init__()
self.positional_encoding = VaswaniPositionalEncoding(len_max)
- self.trunk_joint = create_trunk(
+ self.trunk_A = create_trunk(
dim_model=dim_model,
dim_keys=dim_qk,
dim_hidden=dim_hidden,
nb_heads=nb_heads,
nb_blocks=nb_blocks,
dropout=attention_dropout,
+ residual_masker=residual_masker,
)
- self.trunk_marginal = create_trunk(
+ self.trunk_B = create_trunk(
dim_model=dim_model,
dim_keys=dim_qk,
dim_hidden=dim_hidden,
nb_heads=nb_heads,
nb_blocks=nb_blocks,
dropout=attention_dropout,
+ residual_masker=residual_masker,
)
def forward(self, x_q):
- #!!!!!!!!!!!!!!!!!!!!
- # x_q = torch.cat([x_q[:,200:,:], x_q[:,:200,:]],dim=1)
-
T, S = x_q.size(1), self.x_star.size(0)
nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks
- x_star = self.x_star.reshape(1, S, dim).expand(nb, S, dim)
-
- x = torch.cat([x_star, x_q], dim=1)
- x = self.trunk_joint(x)
-
- f, x = x[:, :S, :], x[:, S:, :]
- x = x.reshape(nb * nc, T // nc, dim)
+ f = self.x_star.reshape(1, S, dim).expand(nb, S, dim)
+ x = x_q
+ x = x.reshape(nb, nc, T // nc, dim).reshape(nb * nc, T // nc, dim)
f = f.repeat(nc, 1, 1)
x = torch.cat([f, x], dim=1)
- x = self.trunk_marginal(x)
+ x = self.trunk_A(x)
+ k = torch.arange(nb, device=x_q.device)
+ u = x[k * 2, :S]
+ x[k * 2, :S] = x[k * 2 + 1, :S]
+ x[k * 2 + 1, :S] = u
+ x = self.trunk_B(x)
- x = x[:, S:, :]
- x = x.reshape(nb, T, dim)
- #!!!!!!!!!!!!!!!!!!!!
- # x = torch.cat([x[:,200:,:], x[:,:200,:]],dim=1)
+ x = x[:, S:]
+ x = x.reshape(nb, nc, T // nc, dim).reshape(nb, T, dim)
+ return x
+ def forward_one_vector(self, x_q):
+ T, S = x_q.size(1), self.x_star.size(0)
+ nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks
+ x_star = self.x_star.reshape(1, S, dim).expand(nb, S, dim)
+ x = torch.cat([x_star, x_q], dim=1)
+ x = self.trunk_A(x)
+ f = x[:, :S, :]
+ x = x_q
+ x = x + 1e-3 * f.mean(dim=1, keepdim=True)
+ x = x.reshape(nb, nc, T // nc, dim).reshape(nb * nc, T // nc, dim)
+ x = self.trunk_B(x)
+ x = x.reshape(nb, nc, T // nc, dim).reshape(nb, T, dim)
return x
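# Plain-tensor walk-through of the bookkeeping in forward() above, with
# nb_chunks == 2 and arbitrary sizes. Each sample is cut into two chunks that
# trunk_A processes independently together with S "f" tokens; the resulting f
# tokens of the two chunks of a sample are then exchanged before trunk_B, so
# information can only flow between chunks through those S vectors. Note that
# x[k * 2, :S] uses advanced indexing and hence returns a copy, which is what
# makes the three-line swap valid.
import torch

nb, nc, T, S, dim = 4, 2, 6, 3, 8
x_q = torch.randn(nb, T, dim)

x = x_q.reshape(nb, nc, T // nc, dim).reshape(nb * nc, T // nc, dim)
# row 2 * i is chunk 0 of sample i, row 2 * i + 1 is chunk 1 of sample i
f = torch.randn(1, S, dim).expand(nb, S, dim).repeat(nc, 1, 1)
x = torch.cat([f, x], dim=1)  # (nb * nc, S + T // nc, dim)

k = torch.arange(nb)
u = x[k * 2, :S]  # copy of the chunk-0 f tokens
x[k * 2, :S] = x[k * 2 + 1, :S]
x[k * 2 + 1, :S] = u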
imt_set = torch.cat([b_p, b_g])
imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
+ batch_size = args.batch_size
+
if train:
label = "train"
model.train().to(local_device)
optimizer_to(model.optimizer, local_device)
- batch_size = args.train_batch_size
else:
label = "test"
model.eval().to(local_device)
- batch_size = args.eval_batch_size
nb_samples, acc_loss = 0, 0.0
# )
i = torch.arange(400)[:, None]
- k = [2**k for k in range(4)] + [10 * 2**k for k in range(4)] + [100, 200]
+ k = [1, 2, 4, 8, 16, 10, 20, 40, 80, 160, 100, 200]
k = torch.tensor(k)[None, :]
- pe = (i // k) % 2
- model.positional_encoding = attae.AdHocPositionalEncoding(args.dim_model, pe)
+ pe = 2.0 * ((i // k) % 2) - 1.0
+
+ model.positional_encoding = attae.AdHocPositionalEncoding(
+ args.dim_model,
+ pe, # trainable=True
+ )
+
+ nb_f_tokens = 100
+ def no_f_residual(x):
+ m = x.new_full((1, x.size(1), 1), 1.0)
+ m[:, :nb_f_tokens, :] = 0
+ return m
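# Quick check of the two objects built above: pe is a (400, 12) grid of +/-1
# square waves (one column per entry of k), and no_f_residual produces a
# (1, x.size(1), 1) mask that removes the residual path on the first
# nb_f_tokens positions (the input sizes below are illustrative).
import torch

i = torch.arange(400)[:, None]
k = torch.tensor([1, 2, 4, 8, 16, 10, 20, 40, 80, 160, 100, 200])[None, :]
pe = 2.0 * ((i // k) % 2) - 1.0
assert pe.shape == (400, 12)
assert set(pe.unique().tolist()) == {-1.0, 1.0}

x = torch.randn(2, 300, 16)
m = no_f_residual(x)
assert m.shape == (1, 300, 1)
assert m[:, :nb_f_tokens].eq(0).all()
assert m[:, nb_f_tokens:].eq(1).all()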
model.trunk = attae.Reasoning(
- nb_f_tokens=8,
+ nb_f_tokens=nb_f_tokens,
nb_chunks=2,
dim_model=args.dim_model,
dim_qk=args.dim_keys,
dim_hidden=args.dim_hidden,
nb_heads=args.nb_heads,
- nb_blocks=args.nb_blocks // 2,
+ nb_blocks=args.nb_blocks,
attention_dropout=args.dropout,
)
test_c_quizzes=None,
local_device=main_device,
)
+ filename = f"aebn_{model.id:03d}.pth"
+ torch.save(
+ {
+ "state_dict": model.state_dict(),
+ "optimizer_state_dict": model.optimizer.state_dict(),
+ "test_accuracy": model.test_accuracy,
+ "nb_epochs": model.nb_epochs,
+ },
+ os.path.join(args.result_dir, filename),
+ )
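# Counterpart sketch for reading the checkpoint written above back in; `model`
# is assumed to have been rebuilt with the same arguments before loading.
import os
import torch

checkpoint = torch.load(
    os.path.join(args.result_dir, f"aebn_{model.id:03d}.pth"),
    map_location="cpu",
)
model.load_state_dict(checkpoint["state_dict"])
model.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
model.test_accuracy = checkpoint["test_accuracy"]
model.nb_epochs = checkpoint["nb_epochs"]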
exit(0)