+ t = gpt.training
+ gpt.eval()
+ model = nn.Linear(args.dim_model, 4).to(device)
+
+ for n_epoch in range(args.nb_epochs):
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+ acc_train_loss, nb_train_samples = 0, 0
+ for input, targets in task.policy_batches(split="train"):
+ output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x
+ output = model(output_gpt)
+ loss = (
+ -(output.log_softmax(-1) * targets).sum(-1).mean()
+ + targets.xlogy(targets).sum(-1).mean()
+ )
+ acc_train_loss += loss.item() * input.size(0)
+ nb_train_samples += input.size(0)
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ acc_test_loss, nb_test_samples = 0, 0
+ for input, targets in task.policy_batches(split="test"):
+ output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x
+ output = model(output_gpt)
+ loss = (
+ -(output.log_softmax(-1) * targets).sum(-1).mean()
+ + targets.xlogy(targets).sum(-1).mean()
+ )
+ acc_test_loss += loss.item() * input.size(0)
+ nb_test_samples += input.size(0)
+
+ log_string(
+ f"diff_ce {n_epoch} train {acc_train_loss/nb_train_samples} test {acc_test_loss/nb_test_samples}"
+ )
+
+ gpt.train(t)