# Added a basic CNN and the code to test it on all the problems.
# [pysvrt.git] / cnn-svrt.py
#!/usr/bin/env python-for-pytorch

3 import time
4
5 import torch
6
7 from torch import optim
8 from torch import FloatTensor as Tensor
9 from torch.autograd import Variable
10 from torch import nn
11 from torch.nn import functional as fn
12 from torchvision import datasets, transforms, utils
13
14 from _ext import svrt
15
######################################################################
# The data

def generate_set(p, n):
    """Generate a labeled set of n vignettes for SVRT problem #p.

    Returns a pair (input, target) of autograd Variables: input is a
    float tensor of shape (n, 1, H, W) — presumably 128x128, per the
    architecture comment below — and target is a LongTensor of n
    binary labels drawn uniformly at random.
    """
    target = torch.LongTensor(n).bernoulli_(0.5)
    # Renamed from `input`, which shadowed the builtin input().
    vignettes = svrt.generate_vignettes(p, target)
    # Insert the singleton channel dimension expected by nn.Conv2d.
    vignettes = vignettes.view(vignettes.size(0), 1,
                               vignettes.size(1), vignettes.size(2)).float()
    return Variable(vignettes), Variable(target)
24
######################################################################

# 128x128 --conv(9)-> 120x120 --max(4)-> 30x30 --conv(6)-> 25x25 --max(5)-> 5x5

class Net(nn.Module):
    """Small two-stage convnet classifying 1x128x128 vignettes into 2 classes."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=9)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=6)
        self.fc1 = nn.Linear(500, 100)
        self.fc2 = nn.Linear(100, 2)

    def forward(self, x):
        # Stage 1: 1x128x128 -> 10x120x120, pooled down to 10x30x30.
        h = self.conv1(x)
        h = fn.relu(fn.max_pool2d(h, kernel_size=4, stride=4))
        # Stage 2: 10x30x30 -> 20x25x25, pooled down to 20x5x5.
        h = self.conv2(h)
        h = fn.relu(fn.max_pool2d(h, kernel_size=5, stride=5))
        # Flatten the 20x5x5 feature maps (500 values) and classify.
        h = h.view(-1, 500)
        h = fn.relu(self.fc1(h))
        return self.fc2(h)
44
def train_model(train_input, train_target):
    """Train a fresh Net on (train_input, train_target) and return it.

    Trains with plain SGD (lr 0.1) and cross-entropy for 25 epochs on
    mini-batches of 100, moving the model and criterion to the GPU when
    one is available (the caller is expected to have moved the data).
    """
    model, criterion = Net(), nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()

    nb_epochs, bs = 25, 100
    optimizer = optim.SGD(model.parameters(), lr=1e-1)
    # Was the module-level global nb_train_samples; derive the size from
    # the actual data so the function works for any set size.
    nb_samples = train_input.size(0)

    for k in range(nb_epochs):
        for b in range(0, nb_samples, bs):
            # Clamp the last batch so nb_samples need not be a multiple of bs
            # (narrow() would raise out-of-range otherwise).
            batch = min(bs, nb_samples - b)
            # model(...) rather than model.forward(...): goes through
            # __call__ so module hooks are honored.
            output = model(train_input.narrow(0, b, batch))
            loss = criterion(output, train_target.narrow(0, b, batch))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return model
64
######################################################################

def print_test_error(model, test_input, test_target):
    """Print the test error rate of model on (test_input, test_target).

    Predictions are the argmax of the model output over the two class
    scores. Prints 'TEST_ERROR ...' and also returns the number of
    misclassified samples (callers may ignore the return value).
    """
    bs = 100
    # Was the module-level global nb_test_samples; derive from the data.
    nb_samples = test_input.size(0)
    nb_test_errors = 0

    for b in range(0, nb_samples, bs):
        # Clamp the last batch so nb_samples need not be a multiple of bs.
        batch = min(bs, nb_samples - b)
        output = model.forward(test_input.narrow(0, b, batch))
        _, wta = torch.max(output.data, 1)
        targets = test_target.narrow(0, b, batch).data
        # Vectorized comparison replaces the per-element Python loop;
        # view(-1) tolerates both keep-dim (old torch) and flat index shapes.
        nb_test_errors += int((wta.view(-1) != targets).long().sum())

    print('TEST_ERROR {:.02f}% ({:d}/{:d})'.format(
        100 * nb_test_errors / nb_samples,
        nb_test_errors,
        nb_samples)
    )

    return nb_test_errors
84
######################################################################

# Set sizes; train_model/print_test_error read these module globals.
nb_train_samples = 100000
nb_test_samples = 10000

# Run the full pipeline — generate data, train, evaluate — on each of
# the 23 SVRT problems, timing every phase.
for problem in range(1, 24):
    print('-- PROBLEM #{:d} --'.format(problem))

    t_start = time.time()

    # Fresh train and test sets for this problem.
    train_input, train_target = generate_set(problem, nb_train_samples)
    test_input, test_target = generate_set(problem, nb_test_samples)

    if torch.cuda.is_available():
        train_input, train_target = train_input.cuda(), train_target.cuda()
        test_input, test_target = test_input.cuda(), test_target.cuda()

    # Normalize both sets with the training-set statistics.
    mu, std = train_input.data.mean(), train_input.data.std()
    train_input.data.sub_(mu).div_(std)
    test_input.data.sub_(mu).div_(std)

    t_data = time.time()
    print('[data generation {:.02f}s]'.format(t_data - t_start))

    model = train_model(train_input, train_target)
    t_train = time.time()
    print('[train {:.02f}s]'.format(t_train - t_data))

    print_test_error(model, test_input, test_target)
    t_test = time.time()
    print('[test {:.02f}s]'.format(t_test - t_train))

    print()
######################################################################