Pass the use of cuda to the VignetteSet constructor.
authorFrancois Fleuret <francois@fleuret.org>
Fri, 16 Jun 2017 05:54:04 +0000 (07:54 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Fri, 16 Jun 2017 05:54:04 +0000 (07:54 +0200)
cnn-svrt.py
vignette_set.py

index a2ab1a3..8b8ec12 100755 (executable)
@@ -27,6 +27,8 @@ import math
 
 from colorama import Fore, Back, Style
 
+# Pytorch
+
 import torch
 
 from torch import optim
@@ -36,6 +38,8 @@ from torch import nn
 from torch.nn import functional as fn
 from torchvision import datasets, transforms, utils
 
+# SVRT
+
 from vignette_set import VignetteSet, CompressedVignetteSet
 
 ######################################################################
@@ -165,11 +169,15 @@ for arg in vars(args):
 
 for problem_number in range(1, 24):
     if args.compress_vignettes:
-        train_set = CompressedVignetteSet(problem_number, args.nb_train_batches, args.batch_size)
-        test_set = CompressedVignetteSet(problem_number, args.nb_test_batches, args.batch_size)
+        train_set = CompressedVignetteSet(problem_number, args.nb_train_batches, args.batch_size,
+                                          cuda=torch.cuda.is_available())
+        test_set = CompressedVignetteSet(problem_number, args.nb_test_batches, args.batch_size,
+                                         cuda=torch.cuda.is_available())
     else:
-        train_set = VignetteSet(problem_number, args.nb_train_batches, args.batch_size)
-        test_set = VignetteSet(problem_number, args.nb_test_batches, args.batch_size)
+        train_set = VignetteSet(problem_number, args.nb_train_batches, args.batch_size,
+                                          cuda=torch.cuda.is_available())
+        test_set = VignetteSet(problem_number, args.nb_test_batches, args.batch_size,
+                                          cuda=torch.cuda.is_available())
 
     model = AfrozeShallowNet()
 
index 0ed3d39..695fed3 100755 (executable)
@@ -31,7 +31,8 @@ import svrt
 ######################################################################
 
 class VignetteSet:
-    def __init__(self, problem_number, nb_batches, batch_size):
+    def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
+        self.cuda = cuda
         self.batch_size = batch_size
         self.problem_number = problem_number
         self.nb_batches = nb_batches
@@ -46,7 +47,7 @@ class VignetteSet:
             target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
             input = svrt.generate_vignettes(problem_number, target)
             input = input.float().view(input.size(0), 1, input.size(1), input.size(2))
-            if torch.cuda.is_available():
+            if self.cuda:
                 input = input.cuda()
                 target = target.cuda()
             acc += input.sum() / input.numel()
@@ -65,7 +66,8 @@ class VignetteSet:
 ######################################################################
 
 class CompressedVignetteSet:
-    def __init__(self, problem_number, nb_batches, batch_size):
+    def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
+        self.cuda = cuda
         self.batch_size = batch_size
         self.problem_number = problem_number
         self.nb_batches = nb_batches
@@ -91,7 +93,7 @@ class CompressedVignetteSet:
         input = input.view(self.batch_size, 1, 128, 128).sub_(self.mean).div_(self.std)
         target = self.targets[b]
 
-        if torch.cuda.is_available():
+        if self.cuda:
             input = input.cuda()
             target = target.cuda()