Pass the use of cuda to the VignetteSet constructor.
[pysvrt.git] / cnn-svrt.py
index a2ab1a3..8b8ec12 100755 (executable)
@@ -27,6 +27,8 @@ import math
 
 from colorama import Fore, Back, Style
 
+# Pytorch
+
 import torch
 
 from torch import optim
@@ -36,6 +38,8 @@ from torch import nn
 from torch.nn import functional as fn
 from torchvision import datasets, transforms, utils
 
+# SVRT
+
 from vignette_set import VignetteSet, CompressedVignetteSet
 
 ######################################################################
@@ -165,11 +169,15 @@ for arg in vars(args):
 
 for problem_number in range(1, 24):
     if args.compress_vignettes:
-        train_set = CompressedVignetteSet(problem_number, args.nb_train_batches, args.batch_size)
-        test_set = CompressedVignetteSet(problem_number, args.nb_test_batches, args.batch_size)
+        train_set = CompressedVignetteSet(problem_number, args.nb_train_batches, args.batch_size,
+                                          cuda=torch.cuda.is_available())
+        test_set = CompressedVignetteSet(problem_number, args.nb_test_batches, args.batch_size,
+                                         cuda=torch.cuda.is_available())
     else:
-        train_set = VignetteSet(problem_number, args.nb_train_batches, args.batch_size)
-        test_set = VignetteSet(problem_number, args.nb_test_batches, args.batch_size)
+        train_set = VignetteSet(problem_number, args.nb_train_batches, args.batch_size,
+                                          cuda=torch.cuda.is_available())
+        test_set = VignetteSet(problem_number, args.nb_test_batches, args.batch_size,
+                                          cuda=torch.cuda.is_available())
 
     model = AfrozeShallowNet()