+def compute_perplexity(model, split="train"):
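+    # Evaluate the model on the given split and return exp of the mean
+    # per-token cross-entropy, i.e. the perplexity. The model's training
+    # mode is saved and restored around the evaluation.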
+    with torch.autograd.no_grad():
+        t = model.training
+        model.eval()
+
+        nb_samples, acc_loss = 0, 0.0
+
+        for input in task.batches(split=split):
+            input = input.to(device)
+
+            output = model(mygpt.BracketedSequence(input)).x
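+            # F.cross_entropy expects logits as (N, C, T), hence the
+            # transpose of the (N, T, C) output before scoring each token.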
+            loss = F.cross_entropy(output.transpose(1, 2), input)
+            acc_loss += loss.item() * input.size(0)
+            nb_samples += input.size(0)
+
+        model.train(t)
+
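+        # Cap the average loss at 100 so that exp() cannot overflow.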
+        return math.exp(min(100, acc_loss / nb_samples))
+
+
+######################################################################
+
+
+def one_shot(gpt, task):
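+    # Train a small linear readout on top of the frozen GPT's internal
+    # representation to predict a distribution over 4 classes (presumably
+    # a policy, given task.policy_batches). Only the linear head is
+    # optimized; the GPT stays in eval mode and is never updated.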
+    t = gpt.training
+    gpt.eval()
+    model = nn.Linear(args.dim_model, 4).to(device)
+
+    for n_epoch in range(args.nb_epochs):
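+        # Note: a fresh Adam optimizer (with reset moment estimates) is
+        # instantiated at every epoch.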
+        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+        acc_train_loss, nb_train_samples = 0, 0
+        for input, targets in task.policy_batches(split="train"):
+            output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x
+            output = model(output_gpt)
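+            # KL(targets || softmax(output)): cross-entropy of the predicted
+            # distribution minus the entropy of the targets (xlogy keeps
+            # 0 * log 0 at 0).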
+            loss = (
+                -(output.log_softmax(-1) * targets).sum(-1).mean()
+                + targets.xlogy(targets).sum(-1).mean()
+            )
+            acc_train_loss += loss.item() * input.size(0)
+            nb_train_samples += input.size(0)
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
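+        # Evaluate the probe on the test split with the same KL objective
+        # (no parameter updates).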
+        acc_test_loss, nb_test_samples = 0, 0
+        for input, targets in task.policy_batches(split="test"):
+            output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x
+            output = model(output_gpt)
+            loss = (
+                -(output.log_softmax(-1) * targets).sum(-1).mean()
+                + targets.xlogy(targets).sum(-1).mean()
+            )
+            acc_test_loss += loss.item() * input.size(0)
+            nb_test_samples += input.size(0)
+
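+        # Log the per-sample train/test losses for this epoch ("diff_ce"
+        # presumably for the difference of cross-entropies, which equals the KL).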
+        log_string(
+            f"diff_ce {n_epoch} train {acc_train_loss/nb_train_samples} test {acc_test_loss/nb_test_samples}"
+        )
+
+    gpt.train(t)
+
+
+######################################################################
+
+