######################################################################
+def sigma_for_grids(input):
+ l = input.size(1) // 4 - 1
+ sigma = input.new(input.size())
+ r = sigma.view(sigma.size(0), 4, sigma.size(1) // 4)
+ r[:, :, 1:] = (
+ torch.rand(input.size(0), 4, l, device=input.device).sort(dim=2).indices
+ ) + 1
+ r[:, 0] += 0 * l
+ r[:, 1] += 1 * l
+ r[:, 2] += 2 * l
+ r[:, 3] += 3 * l
+ return sigma
+
+
def run_tests(model, quiz_machine, local_device=main_device):
with torch.autograd.no_grad():
model.eval().to(local_device)
for input in tqdm.tqdm(src, dynamic_ncols=True, desc="test"):
input = input.to(local_device)
- sigma = torch.rand(input.size(), device=input.device).sort(dim=1).indices
+ sigma = sigma_for_grids(input)
output = model(mygpt.BracketedSequence(input), sigma).x
loss = F.cross_entropy(output.transpose(1, 2), input)
acc_test_loss += loss.item() * input.size(0)
targets = input
- sigma = torch.rand(input.size(), device=input.device).sort(dim=1).indices
+ sigma = sigma_for_grids(input)
output = model(mygpt.BracketedSequence(input), sigma).x
loss_per_token = F.cross_entropy(
output.transpose(1, 2), targets, reduction="none"