Added README.md
[dyncnn.git] / dyncnn.lua
1 #!/usr/bin/env luajit
2
3 --[[
4
5    dyncnn is a deep-learning algorithm for the prediction of
6    interacting object dynamics
7
8    Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
9    Written by Francois Fleuret <francois.fleuret@idiap.ch>
10
11    This file is part of dyncnn.
12
13    dyncnn is free software: you can redistribute it and/or modify it
14    under the terms of the GNU General Public License version 3 as
15    published by the Free Software Foundation.
16
17    dyncnn is distributed in the hope that it will be useful, but
18    WITHOUT ANY WARRANTY; without even the implied warranty of
19    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20    General Public License for more details.
21
22    You should have received a copy of the GNU General Public License
23    along with dyncnn.  If not, see <http://www.gnu.org/licenses/>.
24
25 ]]--
26
27 require 'torch'
28 require 'nn'
29 require 'optim'
30 require 'image'
31
32 require 'fftb'
33
34 ----------------------------------------------------------------------
35 -- Command line arguments
36
37 local cmd = torch.CmdLine()
38
39 cmd:text('General setup')
40
41 cmd:option('-seed', 1, 'initial random seed')
42 cmd:option('-nbThreads', defaultNbThreads, 'how many threads (environment variable TORCH_NB_THREADS)')
43 cmd:option('-useGPU', defaultUseGPU, 'should we use cuda (environment variable TORCH_USE_GPU)')
44 cmd:option('-fastGPU', true, 'should we go as fast as possible, possibly non-deterministically')
45
46 cmd:text('')
47 cmd:text('Log')
48
49 cmd:option('-resultFreq', 100, 'at which epoch frequency should we save result images')
50 cmd:option('-exampleInternals', '', 'list of comma-separated indices for inner activation images')
51 cmd:option('-noLog', false, 'should we prevent logging')
52 cmd:option('-rundir', '', 'the directory for results')
53 cmd:option('-deltaImages', false, 'should we highlight the difference in result images')
54
55 cmd:text('')
56 cmd:text('Network structure')
57
58 cmd:option('-filterSize', 5)
59 cmd:option('-nbChannels', 16)
60 cmd:option('-nbBlocks', 8)
61
62 cmd:text('')
63 cmd:text('Training')
64
65 cmd:option('-nbEpochs', 1000, 'nb of epochs for the heavy setting')
66 cmd:option('-learningRate', 0.1, 'learning rate')
67 cmd:option('-batchSize', 128, 'size of the mini-batches')
68 cmd:option('-nbTrainSamples', 32768)
69 cmd:option('-nbValidationSamples', 1024)
70 cmd:option('-nbTestSamples', 1024)
71
72 cmd:text('')
73 cmd:text('Problem to solve')
74
75 cmd:option('-dataDir', './data/10p-mg', 'data directory')
76
77 cmd:addTime('DYNCNN','%F %T')
78
79 params = cmd:parse(arg)
80
81 ----------------------------------------------------------------------
82
83 fftbInit(cmd, params)
84
85 for _, c in pairs({
86       'date',
87       'uname -a',
88       'git log -1 --format=%H'
89                  })
90 do
91    logCommand(c)
92 end
93
94 ----------------------------------------------------------------------
95
96 function loadData(first, nb, name)
97    print('Loading data `' .. name .. '\'.')
98
99    local data = {}
100
101    data.name = name
102    data.nbSamples = nb
103    data.width = 64
104    data.height = 64
105
106    data.input = ffnn.SlowTensor(data.nbSamples, 2, data.height, data.width)
107    data.target = ffnn.SlowTensor(data.nbSamples, 1, data.height, data.width)
108
109    for i = 1, data.nbSamples do
110       local n = i-1 + first-1
111       local frame = image.load(string.format('%s/%03d/dyn_%06d.png',
112                                              params.dataDir,
113                                              math.floor(n/1000), n))
114
115       frame:mul(-1.0):add(1.0)
116       frame = frame:max(1):select(1, 1)
117
118       data.input[i][1]:copy(frame:sub(0 * data.height + 1, 1 * data.height,
119                                       1 * data.width  + 1, 2 * data.width))
120
121       data.input[i][2]:copy(frame:sub(0 * data.height + 1, 1 * data.height,
122                                       0 * data.width  + 1, 1 * data.width))
123
124       data.target[i][1]:copy(frame:sub(1 * data.height + 1, 2 * data.height,
125                                        1 * data.width  + 1, 2 * data.width))
126    end
127
128    return data
129 end
130
131 ----------------------------------------------------------------------
132
133 function collectAllOutputs(model, collection, which)
134    if torch.type(model) == 'nn.Sequential' then
135       for i = 1, #model.modules do
136          collectAllOutputs(model.modules[i], collection, which)
137       end
138    elseif not which or which[torch.type(model)] then
139       if torch.isTensor(model.output) then
140          collection.nb = collection.nb + 1
141          collection.outputs[collection.nb] = model.output
142       end
143    end
144 end
145
146 function saveInternalsImage(model, data, n)
147    -- Explicitely copy to keep input as a ffnn.FastTensor
148    local input = ffnn.FastTensor(1, 2, data.height, data.width)
149    input:copy(data.input:narrow(1, n, 1))
150
151    local output = model:forward(input)
152
153    local collection = {}
154    collection.outputs = {}
155    collection.nb = 1
156    collection.outputs[collection.nb] = input
157
158    collectAllOutputs(model, collection,
159                      {
160                         ['nn.ReLU'] = true,
161                         ['cunn.ReLU'] = true,
162                         ['cudnn.ReLU'] = true,
163                      }
164    )
165
166    if collection.outputs[collection.nb] ~= model.output then
167       collection.nb = collection.nb + 1
168       collection.outputs[collection.nb] = model.output
169    end
170
171    local fileName = string.format('%s/internals_%s_%06d.png',
172                                   params.rundir,
173                                   data.name, n)
174
175    print('Saving ' .. fileName)
176    image.save(fileName, imageFromTensors(collection.outputs))
177 end
178
179 ----------------------------------------------------------------------
180
181 function highlightImage(a, b)
182    if params.deltaImages then
183       local h = torch.csub(a, b):abs()
184       h:div(1/h:max()):mul(0.9):add(0.1)
185       return torch.cmul(a, h)
186    else
187       return a
188    end
189 end
190
191 function saveResultImage(model, data, nbMax)
192    local criterion = nn.MSECriterion()
193
194    if params.useGPU then
195       print('Moving the criterion to the GPU.')
196       criterion:cuda()
197    end
198
199    local input = ffnn.FastTensor(1, 2, data.height, data.width)
200    local target = ffnn.FastTensor(1, 1, data.height, data.width)
201
202    local nbMax = nbMax or 50
203
204    local nb = math.min(nbMax, data.nbSamples)
205
206    model:evaluate()
207
208    printf('Write %d result images for `%s\'.', nb, data.name)
209
210    local lossFile = io.open(params.rundir .. '/result_' .. data.name .. '_losses.dat', 'w')
211
212    for n = 1, nb do
213
214       -- Explicitely copy to keep input as a ffnn.FastTensor
215       input:copy(data.input:narrow(1, n, 1))
216       target:copy(data.target:narrow(1, n, 1))
217
218       local output = model:forward(input)
219       local loss = criterion:forward(output, target)
220
221       output = ffnn.SlowTensor(output:size()):copy(output)
222
223       -- We use our magical img.lua to create the result images
224
225       local comp
226
227       comp = {
228          {
229             vertical = true,
230             { pad = 1, data.input[n][1] },
231             { pad = 1, data.input[n][2] },
232             { pad = 1, highlightImage(data.target[n][1], data.input[n][1]) },
233             { pad = 1, highlightImage(output[1][1], data.input[n][1]) },
234          }
235       }
236
237       local result = combineImages(1.0, comp)
238
239       result:mul(-1.0):add(1.0)
240
241       local fileName = string.format('result_%s_%06d.png', data.name, n)
242       image.save(params.rundir .. '/' .. fileName, result)
243       lossFile:write(string.format('%f %s\n', loss, fileName))
244    end
245 end
246
247 ----------------------------------------------------------------------
248
249 function createTower(filterSize, nbChannels, nbBlocks)
250
251    local tower
252
253    if nbBlocks == 0 then
254
255       tower = nn.Identity()
256
257    else
258
259       tower = ffnn.Sequential()
260
261       for b = 1, nbBlocks do
262          local block = ffnn.Sequential()
263
264          block:add(ffnn.SpatialConvolution(nbChannels,
265                                            nbChannels,
266                                            filterSize, filterSize,
267                                            1, 1,
268                                            (filterSize - 1) / 2, (filterSize - 1) / 2))
269          block:add(ffnn.SpatialBatchNormalization(nbChannels))
270          block:add(ffnn.ReLU(true))
271
272          block:add(ffnn.SpatialConvolution(nbChannels,
273                                            nbChannels,
274                                            filterSize, filterSize,
275                                            1, 1,
276                                            (filterSize - 1) / 2, (filterSize - 1) / 2))
277
278          local parallel = ffnn.ConcatTable()
279          parallel:add(block):add(ffnn.Identity())
280
281          tower:add(parallel):add(ffnn.CAddTable(true))
282
283          tower:add(ffnn.SpatialBatchNormalization(nbChannels))
284          tower:add(ffnn.ReLU(true))
285       end
286
287    end
288
289    return tower
290 end
291
292 function createModel(imageWidth, imageHeight,
293                      filterSize, nbChannels, nbBlocks)
294
295    local model = ffnn.Sequential()
296
297    -- Encode the two input channels (grasping image and starting
298    -- configuration) into the internal number of channels
299    model:add(ffnn.SpatialConvolution(2,
300                                      nbChannels,
301                                      filterSize, filterSize,
302                                      1, 1,
303                                      (filterSize - 1) / 2, (filterSize - 1) / 2))
304
305    model:add(ffnn.SpatialBatchNormalization(nbChannels))
306    model:add(ffnn.ReLU(true))
307
308    -- Add the resnet modules
309    model:add(createTower(filterSize, nbChannels, nbBlocks))
310
311    -- Decode down to a single channel, which is the final image
312    model:add(ffnn.SpatialConvolution(nbChannels,
313                                      1,
314                                      filterSize, filterSize,
315                                      1, 1,
316                                      (filterSize - 1) / 2, (filterSize - 1) / 2))
317
318    return model
319 end
320
321 ----------------------------------------------------------------------
322
323 function trainModel(model, trainSet, validationSet)
324
325    local criterion = nn.MSECriterion()
326    local batchSize = params.batchSize
327
328    local startingEpoch = 1
329
330    if model.epoch then
331       startingEpoch = model.epoch + 1
332    end
333
334    if model.RNGState then
335       printfc(colors.red, 'Using the RNG state from the loaded model.')
336       torch.setRNGState(model.RNGState)
337    end
338
339    if params.useGPU then
340       print('Moving the model and criterion to the GPU.')
341       model:cuda()
342       criterion:cuda()
343    end
344
345    print('Starting training.')
346
347    local parameters, gradParameters = model:getParameters()
348    printf('The model has %d parameters.', parameters:storage():size(1))
349
350    local averageTrainLoss, averageValidationLoss
351    local trainTime, validationTime
352
353    ----------------------------------------------------------------------
354
355    local sgdState = {
356       learningRate = params.learningRate,
357       momentum = 0,
358       learningRateDecay = 0
359    }
360
361    local batch = {}
362
363    for e = startingEpoch, params.nbEpochs do
364
365       model:training()
366
367       local permutation = torch.randperm(trainSet.nbSamples)
368
369       local accLoss = 0.0
370       local nbBatches = 0
371       local startTime = sys.clock()
372
373       for b = 1, trainSet.nbSamples, batchSize do
374
375          fillBatch(trainSet, b, batch, permutation)
376
377          local opfunc = function(x)
378             -- Surprisingly, copy() needs this check
379             if x ~= parameters then
380                parameters:copy(x)
381             end
382
383             local output = model:forward(batch.input)
384
385             local loss = criterion:forward(output, batch.target)
386             local dLossdOutput = criterion:backward(output, batch.target)
387
388             gradParameters:zero()
389             model:backward(batch.input, dLossdOutput)
390
391             accLoss = accLoss + loss
392             nbBatches = nbBatches + 1
393
394             return loss, gradParameters
395          end
396
397          optim.sgd(opfunc, parameters, sgdState)
398
399       end
400
401       trainTime = sys.clock() - startTime
402       averageTrainLoss = accLoss / nbBatches
403
404       ----------------------------------------------------------------------
405       -- Validation losses
406
407       do
408          model:evaluate()
409
410          local accLoss = 0.0
411          local nbBatches = 0
412          local startTime = sys.clock()
413
414          for b = 1, validationSet.nbSamples, batchSize do
415             fillBatch(validationSet, b, batch)
416             local output = model:forward(batch.input)
417             accLoss = accLoss + criterion:forward(output, batch.target)
418             nbBatches = nbBatches + 1
419          end
420
421          validationTime = sys.clock() - startTime
422          averageValidationLoss = accLoss / nbBatches;
423       end
424
425       ----------------------------------------------------------------------
426
427       printfc(colors.green,
428
429               'epoch %d acc_train_loss %f validation_loss %f [train %.02fs total %.02fms / sample, validation %.02fs total %.02fms / sample]',
430
431               e,
432
433               averageTrainLoss,
434
435               averageValidationLoss,
436
437               trainTime,
438               1000 * trainTime / trainSet.nbSamples,
439
440               validationTime,
441               1000 * validationTime / validationSet.nbSamples
442       )
443
444       ----------------------------------------------------------------------
445       -- Save a persistent state so that we can restart from there
446
447       model:clearState()
448       model.RNGState = torch.getRNGState()
449       model.epoch = e
450       torch.save(params.rundir .. '/model_last.t7', model)
451
452       ----------------------------------------------------------------------
453       -- Save a duplicate of the persistent state from time to time
454
455       if params.resultFreq > 0 and e%params.resultFreq == 0 then
456          torch.save(string.format('%s/model_%04d.t7', params.rundir, e), model)
457          saveResultImage(model, trainSet)
458          saveResultImage(model, validationSet)
459       end
460
461    end
462
463 end
464
465 ----------------------------------------------------------------------
466 -- main
467
468 local trainSet = loadData(1,
469                           params.nbTrainSamples, 'train')
470
471 local validationSet = loadData(params.nbTrainSamples + 1,
472                                params.nbValidationSamples, 'validation')
473
474 local model
475
476 if pcall(function () model = torch.load(params.rundir .. '/model_last.t7') end) then
477
478    printfc(colors.red,
479            'Found a model with %d epochs completed, starting from there.',
480            model.epoch)
481
482    if params.exampleInternals ~= '' then
483       for _, i in ipairs(string.split(params.exampleInternals, ',')) do
484          saveInternalsImage(model, validationSet, tonumber(i))
485       end
486       os.exit(0)
487    end
488
489 else
490
491    model = createModel(trainSet.width, trainSet.height,
492                        params.filterSize, params.nbChannels,
493                        params.nbBlocks)
494
495 end
496
497 trainModel(model, trainSet, validationSet)
498
499 ----------------------------------------------------------------------
500 -- Test
501
502 local testSet = loadData(params.nbTrainSamples + params.nbValidationSamples + 1,
503                          params.nbTestSamples, 'test')
504
505 if params.useGPU then
506    print('Moving the model and criterion to the GPU.')
507    model:cuda()
508 end
509
510 saveResultImage(model, trainSet)
511 saveResultImage(model, validationSet)
512 saveResultImage(model, testSet, 1024)