- for input, targets in task.policy_batches():
- output = gpt(mygpt.BracketedSequence(input), with_readout = False).x
+ model = nn.Linear(args.dim_model, 4).to(device)
+
+ for n_epoch in range(args.nb_epochs):
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+ acc_train_loss, nb_train_samples = 0, 0
+ for input, targets in task.policy_batches(split="train"):
+ output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x
+ output = model(output_gpt)
+ loss = -(output.log_softmax(-1) * targets).sum(-1).mean()
+ acc_train_loss += loss.item() * input.size(0)
+ nb_train_samples += input.size(0)
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ acc_test_loss, nb_test_samples = 0, 0
+ for input, targets in task.policy_batches(split="test"):
+ output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x
+ output = model(output_gpt)
+ loss = -(output.log_softmax(-1) * targets).sum(-1).mean()
+ acc_test_loss += loss.item() * input.size(0)
+ nb_test_samples += input.size(0)
+
+ print(
+ f"{n_epoch=} {acc_train_loss/nb_train_samples=} {acc_test_loss/nb_test_samples=}"
+ )
+