# Evaluate the per-model two-layer MLPs on the held-out test set.
# All tensors carry a leading "model" dimension m (several networks trained
# in parallel). Assumed shapes — TODO confirm against where they are built:
#   test_input:   (m, N, d)  — split into batches along dim=1 (samples)
#   test_targets: (m, N)
#   w1: (m, h, d), b1: (m, h);  w2: (m, c, h), b2: (m, c)
acc_test_loss = 0   # running sum of (mean batch loss * batch sample count)
nb_test_errors = 0  # per-model count of misclassified samples, shape (m,)

for batch_input, batch_targets in zip(
    test_input.split(batch_size, dim=1), test_targets.split(batch_size, dim=1)
):
    # Forward pass, applied independently per model m:
    # hidden = w1 @ x + b1, ReLU, then output = w2 @ hidden + b2.
    h = torch.einsum("mij,mnj->mni", w1, batch_input) + b1[:, None, :]
    h = F.relu(h)
    output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]

    # Flatten (m, n, c) -> (m*n, c) so cross_entropy averages over all
    # model/sample pairs in the batch.
    loss = F.cross_entropy(
        output.reshape(-1, output.size(-1)), batch_targets.reshape(-1)
    )
    # BUG FIX: weight the mean loss by the number of samples in this batch
    # (dim 1), not by the number of models (dim 0). The old input.size(0)
    # was constant across batches and mis-weighted a smaller final batch.
    acc_test_loss += loss.item() * batch_input.size(1)

    wta = output.argmax(-1)  # winner-take-all class predictions, shape (m, n)
    # Accumulate per-model error counts (sum over the sample dimension).
    nb_test_errors += (wta != batch_targets).long().sum(-1)

# Per-model test error rate; test_input.size(1) is the total sample count N.
test_error = nb_test_errors / test_input.size(1)