from colorama import Fore, Back, Style
+# Pytorch
+
import torch
from torch import optim
from torch.nn import functional as fn
from torchvision import datasets, transforms, utils
+# SVRT
+
from vignette_set import VignetteSet, CompressedVignetteSet
######################################################################
self.conv3 = nn.Conv2d(16, 120, kernel_size=18)
self.fc1 = nn.Linear(120, 84)
self.fc2 = nn.Linear(84, 2)
+ self.name = 'shallownet'
def forward(self, x):
x = fn.relu(fn.max_pool2d(self.conv1(x), kernel_size=2))
x = self.fc2(x)
return x
+######################################################################
+
def train_model(model, train_set):
batch_size = args.batch_size
criterion = nn.CrossEntropyLoss()
for problem_number in range(1, 24):
if args.compress_vignettes:
- train_set = CompressedVignetteSet(problem_number, args.nb_train_batches, args.batch_size)
- test_set = CompressedVignetteSet(problem_number, args.nb_test_batches, args.batch_size)
+ train_set = CompressedVignetteSet(problem_number, args.nb_train_batches, args.batch_size,
+ cuda=torch.cuda.is_available())
+ test_set = CompressedVignetteSet(problem_number, args.nb_test_batches, args.batch_size,
+ cuda=torch.cuda.is_available())
else:
- train_set = VignetteSet(problem_number, args.nb_train_batches, args.batch_size)
- test_set = VignetteSet(problem_number, args.nb_test_batches, args.batch_size)
+ train_set = VignetteSet(problem_number, args.nb_train_batches, args.batch_size,
+ cuda=torch.cuda.is_available())
+ test_set = VignetteSet(problem_number, args.nb_test_batches, args.batch_size,
+ cuda=torch.cuda.is_available())
model = AfrozeShallowNet()
nb_parameters += p.numel()
log_string('nb_parameters {:d}'.format(nb_parameters))
- model_filename = 'model_' + str(problem_number) + '.param'
+ model_filename = model.name + '_' + str(problem_number) + '_' + str(train_set.nb_batches) + '.param'
try:
model.load_state_dict(torch.load(model_filename))