c0503444ec7d728a62947aceb8385b36b02c42a5
[dyncnn.git] / dyncnn.lua
1 #!/usr/bin/env luajit
2
3 --[[
4
5    dyncnn is a deep-learning algorithm for the prediction of
6    interacting object dynamics
7
8    Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
9    Written by Francois Fleuret <francois.fleuret@idiap.ch>
10
11    This file is part of dyncnn.
12
13    dyncnn is free software: you can redistribute it and/or modify it
14    under the terms of the GNU General Public License version 3 as
15    published by the Free Software Foundation.
16
17    dyncnn is distributed in the hope that it will be useful, but
18    WITHOUT ANY WARRANTY; without even the implied warranty of
19    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20    General Public License for more details.
21
22    You should have received a copy of the GNU General Public License
23    along with dyncnn.  If not, see <http://www.gnu.org/licenses/>.
24
25 ]]--
26
27 require 'torch'
28 require 'nn'
29 require 'optim'
30 require 'image'
31
32 require 'img'
33
34 ----------------------------------------------------------------------
35
36 function printf(f, ...)
37    print(string.format(f, unpack({...})))
38 end
39
40 colors = sys.COLORS
41
42 function printfc(c, f, ...)
43    print(c .. string.format(f, unpack({...})) .. colors.black)
44 end
45
46 function logCommand(c)
47    print(colors.blue .. '[' .. c .. '] -> [' .. sys.execute(c) .. ']' .. colors.black)
48 end
49
50 ----------------------------------------------------------------------
51 -- Environment variables
52
53 local defaultNbThreads = 1
54 local defaultUseGPU = false
55
56 if os.getenv('TORCH_NB_THREADS') then
57    defaultNbThreads = os.getenv('TORCH_NB_THREADS')
58    print('Environment variable TORCH_NB_THREADS is set and equal to ' .. defaultNbThreads)
59 else
60    print('Environment variable TORCH_NB_THREADS is not set, default is ' .. defaultNbThreads)
61 end
62
63 if os.getenv('TORCH_USE_GPU') then
64    defaultUseGPU = os.getenv('TORCH_USE_GPU') == 'yes'
65    print('Environment variable TORCH_USE_GPU is set and evaluated as ' .. tostring(defaultUseGPU))
66 else
67    print('Environment variable TORCH_USE_GPU is not set, default is ' .. tostring(defaultUseGPU))
68 end
69
70 ----------------------------------------------------------------------
71 -- Command line arguments
72
73 local cmd = torch.CmdLine()
74
75 cmd:text('General setup')
76
77 cmd:option('-seed', 1, 'initial random seed')
78 cmd:option('-nbThreads', defaultNbThreads, 'how many threads (environment variable TORCH_NB_THREADS)')
79 cmd:option('-useGPU', defaultUseGPU, 'should we use cuda (environment variable TORCH_USE_GPU)')
80
81 cmd:text('')
82 cmd:text('Log')
83
84 cmd:option('-resultFreq', 100, 'at which epoch frequency should we save result images')
85 cmd:option('-exampleInternals', '', 'list of comma-separated indices for inner activation images')
86 cmd:option('-noLog', false, 'should we prevent logging')
87 cmd:option('-rundir', '', 'the directory for results')
88 cmd:option('-deltaImages', false, 'should we highlight the difference in result images')
89
90 cmd:text('')
91 cmd:text('Network structure')
92
93 cmd:option('-filterSize', 5)
94 cmd:option('-nbChannels', 16)
95 cmd:option('-nbBlocks', 8)
96
97 cmd:text('')
98 cmd:text('Training')
99
100 cmd:option('-nbEpochs', 2000, 'nb of epochs for the heavy setting')
101 cmd:option('-learningRate', 0.1, 'learning rate')
102 cmd:option('-batchSize', 128, 'size of the mini-batches')
103 cmd:option('-nbTrainSamples', 32768)
104 cmd:option('-nbValidationSamples', 1024)
105 cmd:option('-nbTestSamples', 1024)
106
107 cmd:text('')
108 cmd:text('Problem to solve')
109
110 cmd:option('-dataDir', './data/10p-mg', 'data directory')
111
112 ------------------------------
113 -- Log and stuff
114
115 cmd:addTime('DYNCNN','%F %T')
116
117 params = cmd:parse(arg)
118
119 if params.rundir == '' then
120    params.rundir = cmd:string('exp', params, { })
121 end
122
123 paths.mkdir(params.rundir)
124
125 if not params.noLog then
126    -- Append to the log if there is one
127    cmd:log(io.open(params.rundir .. '/log', 'a'), params)
128 end
129
130 ----------------------------------------------------------------------
131 -- The experiment per se
132
133 if params.predictGrasp then
134    params.targetDepth = 2
135 else
136    params.targetDepth = 1
137 end
138
139 ----------------------------------------------------------------------
140 -- Initializations
141
142 torch.setnumthreads(params.nbThreads)
143 torch.setdefaulttensortype('torch.FloatTensor')
144 torch.manualSeed(params.seed)
145
146 ----------------------------------------------------------------------
147 -- Dealing with the CPU/GPU
148
149 -- mynn will take entries in that order: mynn, cudnn, cunn, nn
150
151 mynn = {}
152
153 setmetatable(mynn,
154              {
155                 __index = function(table, key)
156                    return (cudnn and cudnn[key]) or (cunn and cunn[key]) or nn[key]
157                 end
158              }
159 )
160
161 -- These are the tensors that can be kept on the CPU
162 mynn.SlowTensor = torch.Tensor
163
164 -- These are the tensors that should be moved to the GPU
165 mynn.FastTensor = torch.Tensor
166
167 if params.useGPU then
168    require 'cutorch'
169    require 'cunn'
170    require 'cudnn'
171    cudnn.benchmark = true
172    cudnn.fastest = true
173    mynn.FastTensor = torch.CudaTensor
174 end
175
176 ----------------------------------------------------------------------
177
178 function loadData(first, nb, name)
179    print('Loading data `' .. name .. '\'.')
180
181    local data = {}
182
183    data.name = name
184    data.nbSamples = nb
185    data.width = 64
186    data.height = 64
187
188    data.input = mynn.SlowTensor(data.nbSamples, 2, data.height, data.width)
189    data.target = mynn.SlowTensor(data.nbSamples, 1, data.height, data.width)
190
191    for i = 1, data.nbSamples do
192       local n = i-1 + first-1
193       local frame = image.load(string.format('%s/%03d/dyn_%06d.png',
194                                              params.dataDir,
195                                              math.floor(n/1000), n))
196
197       frame:mul(-1.0):add(1.0)
198       frame = frame:max(1):select(1, 1)
199
200       data.input[i][1]:copy(frame:sub(0 * data.height + 1, 1 * data.height,
201                                       1 * data.width  + 1, 2 * data.width))
202
203       data.input[i][2]:copy(frame:sub(0 * data.height + 1, 1 * data.height,
204                                       0 * data.width  + 1, 1 * data.width))
205
206       data.target[i][1]:copy(frame:sub(1 * data.height + 1, 2 * data.height,
207                                        1 * data.width  + 1, 2 * data.width))
208    end
209
210    return data
211 end
212
213 ----------------------------------------------------------------------
214
215 function collectAllOutputs(model, collection, which)
216    if torch.type(model) == 'nn.Sequential' then
217       for i = 1, #model.modules do
218          collectAllOutputs(model.modules[i], collection, which)
219       end
220    elseif not which or which[torch.type(model)] then
221       if torch.isTensor(model.output) then
222          collection.nb = collection.nb + 1
223          collection.outputs[collection.nb] = model.output
224       end
225    end
226 end
227
228 function saveInternalsImage(model, data, n)
229    -- Explicitely copy to keep input as a mynn.FastTensor
230    local input = mynn.FastTensor(1, 2, data.height, data.width)
231    input:copy(data.input:narrow(1, n, 1))
232
233    local output = model:forward(input)
234
235    local collection = {}
236    collection.outputs = {}
237    collection.nb = 1
238    collection.outputs[collection.nb] = input
239
240    collectAllOutputs(model, collection,
241                      {
242                         ['nn.ReLU'] = true,
243                         ['cunn.ReLU'] = true,
244                         ['cudnn.ReLU'] = true,
245                      }
246    )
247
248    if collection.outputs[collection.nb] ~= model.output then
249       collection.nb = collection.nb + 1
250       collection.outputs[collection.nb] = model.output
251    end
252
253    local fileName = string.format('%s/internals_%s_%06d.png',
254                                   params.rundir,
255                                   data.name, n)
256
257    print('Saving ' .. fileName)
258    image.save(fileName, imageFromTensors(collection.outputs))
259 end
260
261 ----------------------------------------------------------------------
262
263 function highlightImage(a, b)
264    if params.deltaImages then
265       local h = torch.csub(a, b):abs()
266       h:div(1/h:max()):mul(0.9):add(0.1)
267       return torch.cmul(a, h)
268    else
269       return a
270    end
271 end
272
273 function saveResultImage(model, data, nbMax)
274    local criterion = nn.MSECriterion()
275
276    if params.useGPU then
277       print('Moving the criterion to the GPU.')
278       criterion:cuda()
279    end
280
281    local input = mynn.FastTensor(1, 2, data.height, data.width)
282    local target = mynn.FastTensor(1, 1, data.height, data.width)
283
284    local nbMax = nbMax or 50
285
286    local nb = math.min(nbMax, data.nbSamples)
287
288    model:evaluate()
289
290    printf('Write %d result images for `%s\'.', nb, data.name)
291
292    local lossFile = io.open(params.rundir .. '/result_' .. data.name .. '_losses.dat', 'w')
293
294    for n = 1, nb do
295
296       -- Explicitely copy to keep input as a mynn.FastTensor
297       input:copy(data.input:narrow(1, n, 1))
298       target:copy(data.target:narrow(1, n, 1))
299
300       local output = model:forward(input)
301       local loss = criterion:forward(output, target)
302
303       output = mynn.SlowTensor(output:size()):copy(output)
304
305       -- We use our magical img.lua to create the result images
306
307       local comp
308
309       comp = {
310          {
311             vertical = true,
312             { pad = 1, data.input[n][1] },
313             { pad = 1, data.input[n][2] },
314             { pad = 1, highlightImage(data.target[n][1], data.input[n][1]) },
315             { pad = 1, highlightImage(output[1][1], data.input[n][1]) },
316          }
317       }
318
319       local result = combineImages(1.0, comp)
320
321       result:mul(-1.0):add(1.0)
322
323       local fileName = string.format('result_%s_%06d.png', data.name, n)
324       image.save(params.rundir .. '/' .. fileName, result)
325       lossFile:write(string.format('%f %s\n', loss, fileName))
326    end
327 end
328
329 ----------------------------------------------------------------------
330
331 function createTower(filterSize, nbChannels, nbBlocks)
332
333    local tower
334
335    if nbBlocks == 0 then
336
337       tower = nn.Identity()
338
339    else
340
341       tower = mynn.Sequential()
342
343       for b = 1, nbBlocks do
344          local block = mynn.Sequential()
345
346          block:add(mynn.SpatialConvolution(nbChannels,
347                                            nbChannels,
348                                            filterSize, filterSize,
349                                            1, 1,
350                                            (filterSize - 1) / 2, (filterSize - 1) / 2))
351          block:add(mynn.SpatialBatchNormalization(nbChannels))
352          block:add(mynn.ReLU(true))
353
354          block:add(mynn.SpatialConvolution(nbChannels,
355                                            nbChannels,
356                                            filterSize, filterSize,
357                                            1, 1,
358                                            (filterSize - 1) / 2, (filterSize - 1) / 2))
359
360          local parallel = mynn.ConcatTable()
361          parallel:add(block):add(mynn.Identity())
362
363          tower:add(parallel):add(mynn.CAddTable(true))
364
365          tower:add(mynn.SpatialBatchNormalization(nbChannels))
366          tower:add(mynn.ReLU(true))
367       end
368
369    end
370
371    return tower
372
373 end
374
375 function createModel(imageWidth, imageHeight,
376                      filterSize, nbChannels, nbBlocks)
377
378    local model = mynn.Sequential()
379
380    -- Encode the two input channels (grasping image and starting
381    -- configuration) into the internal number of channels
382    model:add(mynn.SpatialConvolution(2,
383                                      nbChannels,
384                                      filterSize, filterSize,
385                                      1, 1,
386                                      (filterSize - 1) / 2, (filterSize - 1) / 2))
387
388    model:add(mynn.SpatialBatchNormalization(nbChannels))
389    model:add(mynn.ReLU(true))
390
391    -- Add the resnet modules
392    model:add(createTower(filterSize, nbChannels, nbBlocks))
393
394    -- Decode down to a single channel, which is the final image
395    model:add(mynn.SpatialConvolution(nbChannels,
396                                      1,
397                                      filterSize, filterSize,
398                                      1, 1,
399                                      (filterSize - 1) / 2, (filterSize - 1) / 2))
400
401    return model
402 end
403
404 ----------------------------------------------------------------------
405
406 function fillBatch(data, first, batch, permutation)
407    local actualBatchSize = math.min(params.batchSize, data.input:size(1) - first + 1)
408
409    if actualBatchSize ~= batch.input:size(1) then
410       local size = batch.input:size()
411       size[1] = actualBatchSize
412       batch.input:resize(size)
413    end
414
415    if actualBatchSize ~= batch.target:size(1) then
416       local size = batch.target:size()
417       size[1] = actualBatchSize
418       batch.target:resize(size)
419    end
420
421    for k = 1, batch.input:size(1) do
422       local i
423       if permutation then
424          i = permutation[first + k - 1]
425       else
426          i = first + k - 1
427       end
428       batch.input[k] = data.input[i]
429       batch.target[k] = data.target[i]
430    end
431 end
432
433 function trainModel(model, trainSet, validationSet)
434
435    local criterion = nn.MSECriterion()
436    local batchSize = params.batchSize
437
438    local batch = {}
439    batch.input = mynn.FastTensor(batchSize, 2, trainSet.height, trainSet.width)
440    batch.target = mynn.FastTensor(batchSize, 1, trainSet.height, trainSet.width)
441
442    local startingEpoch = 1
443
444    if model.epoch then
445       startingEpoch = model.epoch + 1
446    end
447
448    if model.RNGState then
449       printfc(colors.red, 'Using the RNG state from the loaded model.')
450       torch.setRNGState(model.RNGState)
451    end
452
453    if params.useGPU then
454       print('Moving the model and criterion to the GPU.')
455       model:cuda()
456       criterion:cuda()
457    end
458
459    print('Starting training.')
460
461    local parameters, gradParameters = model:getParameters()
462    printf('The model has %d parameters.', parameters:storage():size(1))
463
464    local averageTrainLoss, averageValidationLoss
465    local trainTime, validationTime
466
467    ----------------------------------------------------------------------
468
469    local sgdState = {
470       learningRate = params.learningRate,
471       momentum = 0,
472       learningRateDecay = 0
473    }
474
475    for e = startingEpoch, params.nbEpochs do
476
477       model:training()
478
479       local permutation = torch.randperm(trainSet.nbSamples)
480
481       local accLoss = 0.0
482       local nbBatches = 0
483       local startTime = sys.clock()
484
485       for b = 1, trainSet.nbSamples, batchSize do
486
487          fillBatch(trainSet, b, batch, permutation)
488
489          local opfunc = function(x)
490             -- Surprisingly, copy() needs this check
491             if x ~= parameters then
492                parameters:copy(x)
493             end
494
495             local output = model:forward(batch.input)
496
497             local loss = criterion:forward(output, batch.target)
498             local dLossdOutput = criterion:backward(output, batch.target)
499
500             gradParameters:zero()
501             model:backward(batch.input, dLossdOutput)
502
503             accLoss = accLoss + loss
504             nbBatches = nbBatches + 1
505
506             return loss, gradParameters
507          end
508
509          optim.sgd(opfunc, parameters, sgdState)
510
511       end
512
513       trainTime = sys.clock() - startTime
514       averageTrainLoss = accLoss / nbBatches
515
516       ----------------------------------------------------------------------
517       -- Validation losses
518
519       do
520          model:evaluate()
521
522          local accLoss = 0.0
523          local nbBatches = 0
524          local startTime = sys.clock()
525
526          for b = 1, validationSet.nbSamples, batchSize do
527             fillBatch(validationSet, b, batch)
528             local output = model:forward(batch.input)
529             accLoss = accLoss + criterion:forward(output, batch.target)
530             nbBatches = nbBatches + 1
531          end
532
533          validationTime = sys.clock() - startTime
534          averageValidationLoss = accLoss / nbBatches;
535       end
536
537       ----------------------------------------------------------------------
538
539       printfc(colors.green,
540
541               'epoch %d acc_train_loss %f validation_loss %f [train %.02fs total %.02fms / sample, validation %.02fs total %.02fms / sample]',
542
543               e,
544
545               averageTrainLoss,
546
547               averageValidationLoss,
548
549               trainTime,
550               1000 * trainTime / trainSet.nbSamples,
551
552               validationTime,
553               1000 * validationTime / validationSet.nbSamples
554       )
555
556       ----------------------------------------------------------------------
557       -- Save a persistent state so that we can restart from there
558
559       model:clearState()
560       model.RNGState = torch.getRNGState()
561       model.epoch = e
562       torch.save(params.rundir .. '/model_last.t7', model)
563
564       ----------------------------------------------------------------------
565       -- Save a duplicate of the persistent state from time to time
566
567       if params.resultFreq > 0 and e%params.resultFreq == 0 then
568          torch.save(string.format('%s/model_%04d.t7', params.rundir, e), model)
569          saveResultImage(model, trainSet)
570          saveResultImage(model, validationSet)
571       end
572
573    end
574
575 end
576
577 function createAndTrainModel(trainSet, validationSet)
578
579    -- Load the current training state, or create a new model from
580    -- scratch
581
582    if pcall(function () model = torch.load(params.rundir .. '/model_last.t7') end) then
583
584       printfc(colors.red,
585               'Found a model with %d epochs completed, starting from there.',
586               model.epoch)
587
588       if params.exampleInternals ~= '' then
589          for _, i in ipairs(string.split(params.exampleInternals, ',')) do
590             saveInternalsImage(model, validationSet, tonumber(i))
591          end
592          os.exit(0)
593       end
594
595    else
596
597       model = createModel(trainSet.width, trainSet.height,
598                           params.filterSize, params.nbChannels,
599                           params.nbBlocks)
600
601    end
602
603    trainModel(model, trainSet, validationSet)
604
605    return model
606
607 end
608
609 ----------------------------------------------------------------------
610 -- main
611
612 for _, c in pairs({
613       'date',
614       'uname -a',
615       'git log -1 --format=%H'
616                  })
617 do
618    logCommand(c)
619 end
620
621 local trainSet = loadData(1,
622                           params.nbTrainSamples, 'train')
623
624 local validationSet = loadData(params.nbTrainSamples + 1,
625                                params.nbValidationSamples, 'validation')
626
627 local model = createAndTrainModel(trainSet, validationSet)
628
629 ----------------------------------------------------------------------
630 -- Test
631
632 local testSet = loadData(params.nbTrainSamples + params.nbValidationSamples + 1,
633                          params.nbTestSamples, 'test')
634
635 if params.useGPU then
636    print('Moving the model and criterion to the GPU.')
637    model:cuda()
638 end
639
640 saveResultImage(model, trainSet)
641 saveResultImage(model, validationSet)
642 saveResultImage(model, testSet, 1024)