type = str, default = 'default.log')
parser.add_argument('--nb_exemplar_vignettes',
- type = int, default = -1)
+ type = int, default = 32)
parser.add_argument('--compress_vignettes',
type = distutils.util.strtobool, default = 'True',
help = 'Use lossless compression to reduce the memory footprint')
-parser.add_argument('--deep_model',
- type = distutils.util.strtobool, default = 'True',
- help = 'Use Afroze\'s Alexnet-like deep model')
+parser.add_argument('--model',
+ type = str, default = 'deepnet',
+ help = 'What model to use')
parser.add_argument('--test_loaded_models',
type = distutils.util.strtobool, default = 'False',
# -- full(84x2) -> 2 1
class AfrozeShallowNet(nn.Module):
+ name = 'shallownet'
+
def __init__(self):
super(AfrozeShallowNet, self).__init__()
self.conv1 = nn.Conv2d(1, 6, kernel_size=21)
self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
self.fc1 = nn.Linear(120, 84)
self.fc2 = nn.Linear(84, 2)
- self.name = 'shallownet'
def forward(self, x):
x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
# Afroze's DeepNet
class AfrozeDeepNet(nn.Module):
+
+ name = 'deepnet'
+
def __init__(self):
super(AfrozeDeepNet, self).__init__()
self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
self.fc1 = nn.Linear(1536, 256)
self.fc2 = nn.Linear(256, 256)
self.fc3 = nn.Linear(256, 2)
- self.name = 'deepnet'
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = fn.max_pool2d(x, kernel_size=2)
+ x = fn.relu(x)
+
+ x = self.conv2(x)
+ x = fn.max_pool2d(x, kernel_size=2)
+ x = fn.relu(x)
+
+ x = self.conv3(x)
+ x = fn.relu(x)
+
+ x = self.conv4(x)
+ x = fn.relu(x)
+
+ x = self.conv5(x)
+ x = fn.max_pool2d(x, kernel_size=2)
+ x = fn.relu(x)
+
+ x = x.view(-1, 1536)
+
+ x = self.fc1(x)
+ x = fn.relu(x)
+
+ x = self.fc2(x)
+ x = fn.relu(x)
+
+ x = self.fc3(x)
+
+ return x
+
+######################################################################
+
+class DeepNet2(nn.Module):
+ name = 'deepnet2'
+
+ def __init__(self):
+ super(DeepNet2, self).__init__()
+ self.conv1 = nn.Conv2d( 1, 32, kernel_size=7, stride=4, padding=3)
+ self.conv2 = nn.Conv2d( 32, 128, kernel_size=5, padding=2)
+ self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+ self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+ self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+ self.fc1 = nn.Linear(2048, 512)
+ self.fc2 = nn.Linear(512, 512)
+ self.fc3 = nn.Linear(256, 2)
def forward(self, x):
x = self.conv1(x)
log_string('using_uncompressed_vignettes')
VignetteSet = svrtset.VignetteSet
+########################################
+model_class = None
+for m in [ AfrozeShallowNet, AfrozeDeepNet, DeepNet2 ]:
+ if args.model == m.name:
+ model_class = m
+ break
+if model_class is None:
+ print('Unknown model ' + args.model)
+ raise
+
+log_string('using model class ' + m.name)
+########################################
+
for problem_number in map(int, args.problems.split(',')):
log_string('############### problem ' + str(problem_number) + ' ###############')
- if args.deep_model:
- model = AfrozeDeepNet()
- else:
- model = AfrozeShallowNet()
+ model = model_class()
if torch.cuda.is_available(): model.cuda()