-test_input = test_set.test_data.view(-1, 1, 28, 28).float()
-test_target = test_set.test_labels
-
-if torch.cuda.is_available():
- used_MNIST_classes = used_MNIST_classes.cuda()
- train_input, train_target = train_input.cuda(), train_target.cuda()
- test_input, test_target = test_input.cuda(), test_target.cuda()
+test_input = test_set.test_data.view(-1, 1, 28, 28).to(device).float()
+test_target = test_set.test_labels.to(device)