######################################################################
class Reasoning(nn.Module):
    """Two-pass attention module built around learned "function" tokens.

    A set of `nb_f_tokens` learned embeddings (`x_star`) is prepended to the
    input sequence and processed jointly by a first trunk; the resulting
    function tokens are then broadcast to every chunk of the sequence and
    processed again, chunk by chunk, by a second trunk.

    NOTE(review): the `attention` argument is accepted but never stored nor
    forwarded to `create_trunk` — confirm whether it should be passed
    through or removed.
    """

    def __init__(
        self,
        nb_f_tokens,
        nb_chunks,
        dim_model,
        dim_qk,
        dim_hidden,
        nb_heads=1,
        nb_blocks=1,
        attention=vanilla_attention,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            # Gaussian init scaled by 1/sqrt(last dim), i.e. fan-in scaling.
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

        self.nb_chunks = nb_chunks
        # Learned function tokens, shared across the batch
        # (expanded per sample in forward()).
        self.x_star = randw(nb_f_tokens, dim_model)

        # First pass: attends over [function tokens | whole sequence].
        self.trunk_joint = create_trunk(
            dim_model=dim_model,
            dim_keys=dim_qk,
            dim_hidden=dim_hidden,
            nb_heads=nb_heads,
            nb_blocks=nb_blocks,
            dropout=attention_dropout,
        )

        # Second pass: attends over [function tokens | one chunk] per chunk.
        self.trunk_marginal = create_trunk(
            dim_model=dim_model,
            dim_keys=dim_qk,
            dim_hidden=dim_hidden,
            nb_heads=nb_heads,
            nb_blocks=nb_blocks,
            dropout=attention_dropout,
        )

    def forward(self, x_q):
        """Run the joint pass, then the per-chunk marginal pass.

        Assumes x_q is (batch, T, dim_model) with T divisible by
        `self.nb_chunks` — TODO confirm with callers; the reshape below
        fails otherwise.  Returns a tensor of the same shape as `x_q`.
        """
        T, S = x_q.size(1), self.x_star.size(0)
        nb, dim = x_q.size(0), x_q.size(2)
        # Broadcast the learned tokens over the batch dimension.
        x_star = self.x_star.unsqueeze(0).expand(nb, -1, -1)

        # Joint pass over [function tokens | full sequence].
        x = torch.cat([x_star, x_q], dim=1)
        x = self.trunk_joint(x)

        # Split back: f = updated function tokens, x = updated sequence.
        f, x = x[:, : x_star.size(1), :], x[:, x_star.size(1) :, :]
        # Fold the chunk axis into the batch so every chunk is processed
        # independently, each paired with its own copy of the tokens.
        x = x.reshape(nb, self.nb_chunks, T // self.nb_chunks, dim)
        f = f.unsqueeze(1).expand(nb, self.nb_chunks, S, dim)
        x = x.reshape(nb * self.nb_chunks, T // self.nb_chunks, dim)
        f = f.reshape(nb * self.nb_chunks, S, dim)
        x = torch.cat([f, x], dim=1)
        x = self.trunk_marginal(x)

        # Drop the function tokens and restore the (batch, T, dim) layout.
        x = x[:, x_star.size(1) :, :]
        x = x.reshape(nb, T, dim)

        return x
+
+
+######################################################################
+
+
class LearningToBeMe(nn.Module):
    """Wrap a reference module and a student module that imitates it.

    Each forward pass measures, per sample, the relative L2 error between
    the two outputs.  Samples where the student is within `epsilon` use
    the student's output; the others fall back to the reference.  Side
    effects: stores `self.error` (mean relative error over the batch) and
    `self.nb_me` (fraction of samples still served by the reference).
    """

    def __init__(self, real_f, mimic_f, epsilon):
        super().__init__()
        self.real_f = real_f
        self.mimic_f = mimic_f
        self.epsilon = epsilon

    def forward(self, x):
        out_real = self.real_f(x)
        # The student sees a detached input so its gradient does not reach
        # whatever produced x.
        out_mimic = self.mimic_f(x.detach())

        flat_real = out_real.flatten(1)
        flat_mimic = out_mimic.flatten(1)

        # Per-sample relative error ||real - mimic|| / ||real||; the real
        # branch is detached, so only the student is trained by this error.
        ref = flat_real.detach()
        num = (ref - flat_mimic).pow(2).sum(dim=1, keepdim=True)
        den = ref.pow(2).sum(dim=1, keepdim=True)
        e = (num / den).sqrt()
        self.error = e.mean()

        # m = 1 where the student is accurate enough to replace the reference.
        m = (e <= self.epsilon).float()
        self.nb_me = (1 - m).mean()

        mixed = (1 - m) * flat_real + m * flat_mimic
        return mixed.reshape(out_real.shape)
+
+
+######################################################################
+
+
if __name__ == "__main__":
model = FunctionalAttentionAE(
vocabulary_size_in=100,
logits.transpose(1, 2), targets, reduction="none"
)
loss = (loss_per_token * masks).mean()
+
+ if args.test == "aebn":
+ error = 0
+ for m in model.modules():
+ if hasattr(m, "error"):
+ error = error + m.error
+ loss = loss + error
+
acc_loss += loss.item() * imt.size(0)
nb_samples += imt.size(0)
if nb_samples % args.batch_size == 0:
model.optimizer.step()
+ if args.test == "aebn":
+ nb_me = []
+ for m in model.modules():
+ if hasattr(m, "nb_me"):
+ nb_me.append(m.nb_me.item())
+
+ log_string(f"{label}_error {n_epoch} model {model.id} {error} nb_me {nb_me}")
+
log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}")
######################################################################
if args.test == "aebn":
- model = new_model()
+ model = new_model(0)
+ # f = model.trunk[len(model.trunk) // 2 :]
+ model.trunk = attae.Reasoning(
+ nb_f_tokens=250,
+ nb_chunks=2,
+ dim_model=args.dim_model,
+ dim_qk=args.dim_keys,
+ dim_hidden=args.dim_hidden,
+ nb_heads=args.nb_heads,
+ nb_blocks=args.nb_blocks // 2,
+ attention_dropout=args.dropout,
+ )
- # model.trunk = (
- # model.trunk[: len(model.trunk) // 2] + model.trunk[len(model.trunk) // 2 :]
+ # model.trunk = model.trunk[: len(model.trunk) // 2] + nn.Sequential(
+ # attae.LearningToBeMe(f, g, 1e-1)
# )
- model.id = 0
- model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
- model.test_accuracy = 0.0
- model.nb_epochs = 0
+ # model.id = 0
+ # model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+ # model.test_accuracy = 0.0
+ # model.nb_epochs = 0
for n_epoch in range(args.nb_epochs):
one_complete_epoch(
local_device=main_device,
)
+ exit(0)
+
######################################################################
train_c_quizzes, test_c_quizzes = None, None