- model.eval()
-
- nb_test_samples, acc_test_loss = 0, 0.0
-
- for input in task.batches(split="test"):
- input = input.to(device)
-
- # input, loss_masks, true_images = task.excise_last_image(input)
- # input, loss_masks = task.add_true_image(input, true_images, loss_masks)
-
- output = model(mygpt.BracketedSequence(input)).x
- loss = F.cross_entropy(output.transpose(1, 2), input)
- acc_test_loss += loss.item() * input.size(0)
- nb_test_samples += input.size(0)
-
- train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
- test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
-
- log_string(
- f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
- )