Update.
authorFrancois Fleuret <francois@fleuret.org>
Thu, 22 Jun 2017 06:49:10 +0000 (08:49 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Thu, 22 Jun 2017 06:49:10 +0000 (08:49 +0200)
cnn-svrt.py

index 7fe2db2..227d9b4 100755 (executable)
@@ -77,15 +77,15 @@ parser.add_argument('--log_file',
                     type = str, default = 'default.log')
 
 parser.add_argument('--nb_exemplar_vignettes',
-                    type = int, default = -1)
+                    type = int, default = 32)
 
 parser.add_argument('--compress_vignettes',
                     type = distutils.util.strtobool, default = 'True',
                     help = 'Use lossless compression to reduce the memory footprint')
 
-parser.add_argument('--deep_model',
-                    type = distutils.util.strtobool, default = 'True',
-                    help = 'Use Afroze\'s Alexnet-like deep model')
+parser.add_argument('--model',
+                    type = str, default = 'deepnet',
+                    help = 'What model to use')
 
 parser.add_argument('--test_loaded_models',
                     type = distutils.util.strtobool, default = 'False',
@@ -146,6 +146,8 @@ def log_string(s, remark = ''):
 # -- full(84x2)        -> 2          1
 
 class AfrozeShallowNet(nn.Module):
+    name = 'shallownet'
+
     def __init__(self):
         super(AfrozeShallowNet, self).__init__()
         self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
@@ -153,7 +155,6 @@ class AfrozeShallowNet(nn.Module):
         self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
         self.fc1 = nn.Linear(120, 84)
         self.fc2 = nn.Linear(84, 2)
-        self.name = 'shallownet'
 
     def forward(self, x):
         x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
@@ -169,6 +170,9 @@ class AfrozeShallowNet(nn.Module):
 # Afroze's DeepNet
 
 class AfrozeDeepNet(nn.Module):
+
+    name = 'deepnet'
+
     def __init__(self):
         super(AfrozeDeepNet, self).__init__()
         self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
@@ -179,7 +183,53 @@ class AfrozeDeepNet(nn.Module):
         self.fc1 = nn.Linear(1536, 256)
         self.fc2 = nn.Linear(256, 256)
         self.fc3 = nn.Linear(256, 2)
-        self.name = 'deepnet'
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = self.conv2(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = self.conv3(x)
+        x = fn.relu(x)
+
+        x = self.conv4(x)
+        x = fn.relu(x)
+
+        x = self.conv5(x)
+        x = fn.max_pool2d(x, kernel_size=2)
+        x = fn.relu(x)
+
+        x = x.view(-1, 1536)
+
+        x = self.fc1(x)
+        x = fn.relu(x)
+
+        x = self.fc2(x)
+        x = fn.relu(x)
+
+        x = self.fc3(x)
+
+        return x
+
+######################################################################
+
+class DeepNet2(nn.Module):
+    name = 'deepnet2'
+
+    def __init__(self):
+        super(DeepNet2, self).__init__()
+        self.conv1 = nn.Conv2d(  1,  32, kernel_size=7, stride=4, padding=3)
+        self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
+        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.fc1 = nn.Linear(2048, 512)
+        self.fc2 = nn.Linear(512, 512)
+        self.fc3 = nn.Linear(256, 2)
 
     def forward(self, x):
         x = self.conv1(x)
@@ -331,14 +381,24 @@ else:
     log_string('using_uncompressed_vignettes')
     VignetteSet = svrtset.VignetteSet
 
+########################################
+model_class = None
+for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2 ]:
+    if args.model == m.name:
+        model_class = m
+        break
+if model_class is None:
+    print('Unknown model ' + args.model)
+    raise
+
+log_string('using model class ' + m.name)
+########################################
+
 for problem_number in map(int, args.problems.split(',')):
 
     log_string('############### problem ' + str(problem_number) + ' ###############')
 
-    if args.deep_model:
-        model = AfrozeDeepNet()
-    else:
-        model = AfrozeShallowNet()
+    model = model_class()
 
     if torch.cuda.is_available(): model.cuda()