residual_masker=residual_masker,
)
+ # NOTE(review): two structurally identical attention modules --
+ # presumably mha_A reads chunk features into the learned x_star
+ # slot tokens and mha_B writes them back (used by forward_AB);
+ # confirm intent. dim_v = dim_model // nb_heads so the heads'
+ # concatenated output width matches dim_model.
+ self.mha_A = MHAttention(
+ dim_model=dim_model,
+ dim_qk=dim_qk,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ attention=vanilla_attention,
+ attention_dropout=attention_dropout,
+ )
+
+ self.mha_B = MHAttention(
+ dim_model=dim_model,
+ dim_qk=dim_qk,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ attention=vanilla_attention,
+ attention_dropout=attention_dropout,
+ )
+
+ def forward_AB(self, x_q):
+ # Chunked pass: split the sequence into nb_chunks, encode each
+ # chunk with trunk_A, read it into the S learned x_star slots via
+ # mha_A, exchange slot features pairwise, then (apparently) write
+ # back via mha_B and finish with trunk_B.
+ # Assumes x_q is (batch, T, dim_model) with T divisible by
+ # self.nb_chunks -- TODO confirm with callers.
+ T, S = x_q.size(1), self.x_star.size(0)
+ nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks
+
+ x = x_q
+ # (nb, T, dim) -> (nb*nc, T//nc, dim): one row per chunk.
+ x = x.reshape(nb, nc, T // nc, dim).reshape(nb * nc, T // nc, dim)
+ x = self.trunk_A(x)
+ # Broadcast the slot tokens to every chunk row; slots attend over
+ # the chunk (queries = slots, keys/values = x).
+ f = self.x_star.reshape(1, S, dim).expand(nb * nc, S, dim)
+ f = self.mha_A(f, x)
+
+ # Swap slot features between rows 2k and 2k+1 (advanced indexing
+ # returns a copy, so u is a safe temporary). NOTE(review): k
+ # ranges over nb but indexes the nb*nc chunk axis -- this pairs a
+ # sample's two chunks only when nb_chunks == 2; for any other nc
+ # only the first 2*nb rows are touched. Confirm nc is fixed at 2.
+ k = torch.arange(nb, device=x_q.device)
+ u = f[k * 2, :]
+ f[k * 2, :] = f[k * 2 + 1, :]
+ f[k * 2 + 1, :] = u
+
+ # NOTE(review): the mha_B result is assigned to f but never used
+ # -- trunk_B is applied to x, so the swapped slot features do not
+ # reach the output. Looks like a bug (perhaps
+ # x = self.trunk_B(self.mha_B(x, f)) was intended); confirm.
+ f = self.mha_B(x, f)
+ x = self.trunk_B(x)
+ # Undo the chunking: (nb*nc, T//nc, dim) -> (nb, T, dim).
+ x = x.reshape(nb, nc, T // nc, dim).reshape(nb, T, dim)
+
+ return x
+
def forward(self, x_q):
T, S = x_q.size(1), self.x_star.size(0)
nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks
if nb_samples % args.batch_size == 0:
model.optimizer.step()
+ if train:
+ model.nb_epochs += 1
+
log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}")
pe, # trainable=True
)
- nb_f_tokens = 100
+ nb_f_tokens = 200
def no_f_residual(x):
m = x.new_full((1, x.size(1), 1), 1.0)
model.test_accuracy = 0.0
model.nb_epochs = 0
+ # Hunk: replace the unconditional train-then-save flow with a
+ # resume-or-train branch keyed on args.resume.
- for n_epoch in range(args.nb_epochs):
- one_complete_epoch(
- model,
- n_epoch,
- train_c_quizzes=None,
- test_c_quizzes=None,
- local_device=main_device,
- )
+ if args.resume:
filename = f"aebn_{model.id:03d}.pth"
- torch.save(
- {
- "state_dict": model.state_dict(),
- "optimizer_state_dict": model.optimizer.state_dict(),
- "test_accuracy": model.test_accuracy,
- "nb_epochs": model.nb_epochs,
- },
+
+ # Resume: restore model/optimizer state and bookkeeping from the
+ # per-model checkpoint. NOTE(review): weights_only=False lets
+ # torch.load unpickle arbitrary objects -- only safe because the
+ # checkpoint is produced by this same program; never load an
+ # untrusted file this way.
+ d = torch.load(
os.path.join(args.result_dir, filename),
+ map_location="cpu",
+ weights_only=False,
)
+ model.load_state_dict(d["state_dict"])
+ model.optimizer.load_state_dict(d["optimizer_state_dict"])
+ model.test_accuracy = d["test_accuracy"]
+ model.nb_epochs = d["nb_epochs"]
+ # NOTE(review): "(unknown)" in this log line looks like a
+ # redaction placeholder -- presumably the checkpoint filename;
+ # confirm against the original source.
+ log_string(f"successfully loaded (unknown) nb_epochs {model.nb_epochs}")
+
+ else:
+ # Fresh run: train for nb_epochs, then checkpoint everything the
+ # resume branch above reads back.
+ for n_epoch in range(args.nb_epochs):
+ one_complete_epoch(
+ model,
+ n_epoch,
+ train_c_quizzes=None,
+ test_c_quizzes=None,
+ local_device=main_device,
+ )
+ filename = f"aebn_{model.id:03d}.pth"
+ torch.save(
+ {
+ "state_dict": model.state_dict(),
+ "optimizer_state_dict": model.optimizer.state_dict(),
+ "test_accuracy": model.test_accuracy,
+ "nb_epochs": model.nb_epochs,
+ },
+ os.path.join(args.result_dir, filename),
+ )
exit(0)