Use cudnn more efficiently.
[dyncnn.git] / dyncnn.lua
1 #!/usr/bin/env luajit
2
3 --[[
4
5    dyncnn is a deep-learning algorithm for the prediction of
6    interacting object dynamics
7
8    Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
9    Written by Francois Fleuret <francois.fleuret@idiap.ch>
10
11    This file is part of dyncnn.
12
13    dyncnn is free software: you can redistribute it and/or modify it
14    under the terms of the GNU General Public License version 3 as
15    published by the Free Software Foundation.
16
17    dyncnn is distributed in the hope that it will be useful, but
18    WITHOUT ANY WARRANTY; without even the implied warranty of
19    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20    General Public License for more details.
21
22    You should have received a copy of the GNU General Public License
23    along with dyncnn.  If not, see <http://www.gnu.org/licenses/>.
24
25 ]]--
26
27 require 'torch'
28 require 'nn'
29 require 'optim'
30 require 'image'
31 require 'pl'
32
33 ----------------------------------------------------------------------
34
35 local opt = lapp[[
36    --seed                (default 1)               random seed
37
38    --learningStateFile   (default '')
39    --dataDir             (default './data/10p-mg/')
40    --resultDir           (default '/tmp/dyncnn')
41
42    --learningRate        (default -1)
43    --momentum            (default -1)
44    --nbEpochs            (default -1)              nb of epochs for the heavy setting
45
46    --heavy                                         use the heavy configuration
47    --nbChannels          (default -1)              nb of channels in the internal layers
48    --resultFreq          (default 100)
49
50    --noLog                                         supress logging
51
52    --exampleInternals    (default -1)
53 ]]
54
55 ----------------------------------------------------------------------
56
57 commandLine=''
58 for i = 0, #arg do
59    commandLine = commandLine ..  ' \'' .. arg[i] .. '\''
60 end
61
62 ----------------------------------------------------------------------
63
64 colors = sys.COLORS
65
66 global = {}
67
68 function logString(s, c)
69    if global.logFile then
70       global.logFile:write(s)
71       global.logFile:flush()
72    end
73    local c = c or colors.black
74    io.write(c .. s)
75    io.flush()
76 end
77
78 function logCommand(c)
79    logString('[' .. c .. '] -> [' .. sys.execute(c) .. ']\n', colors.blue)
80 end
81
82 logString('commandline: ' .. commandLine .. '\n', colors.blue)
83
84 logCommand('mkdir -v -p ' .. opt.resultDir)
85
86 if not opt.noLog then
87    global.logName = opt.resultDir .. '/log'
88    global.logFile = io.open(global.logName, 'a')
89 end
90
91 ----------------------------------------------------------------------
92
93 alreadyLoggedString = {}
94
95 function logOnce(s)
96    local l = debug.getinfo(1).currentline
97    if not alreadyLoggedString[l] then
98       logString('@line ' .. l .. ' ' .. s, colors.red)
99       alreadyLoggedString[l] = s
100    end
101 end
102
103 ----------------------------------------------------------------------
104
105 nbThreads = os.getenv('TORCH_NB_THREADS') or 1
106
107 useGPU = os.getenv('TORCH_USE_GPU') == 'yes'
108
109 for _, c in pairs({ 'date',
110                     'uname -a',
111                     'git log -1 --format=%H'
112                  })
113 do
114    logCommand(c)
115 end
116
117 logString('useGPU is \'' .. tostring(useGPU) .. '\'.\n')
118
119 logString('nbThreads is \'' .. nbThreads .. '\'.\n')
120
121 ----------------------------------------------------------------------
122
123 torch.setnumthreads(nbThreads)
124 torch.setdefaulttensortype('torch.FloatTensor')
125 torch.manualSeed(opt.seed)
126
127 mynn = {}
128
129 -- To deal elegantly with CPU/GPU
130 local mt = {}
131 function mt.__index(table, key)
132    return (cudnn and cudnn[key]) or (cunn and cunn[key]) or nn[key]
133 end
134 setmetatable(mynn, mt)
135
136 -- These are the tensors that can be kept on the CPU
137 mynn.SlowTensor = torch.Tensor
138 -- These are the tensors that should be moved to the GPU
139 mynn.FastTensor = torch.Tensor
140
141 ----------------------------------------------------------------------
142
143 if useGPU then
144    require 'cutorch'
145    require 'cunn'
146    require 'cudnn'
147
148    mynn.FastTensor = torch.CudaTensor
149
150    if cudnn then
151       cudnn.benchmark = true
152       cudnn.fastest = true
153    end
154 end
155
156 ----------------------------------------------------------------------
157
158 config = {}
159 config.learningRate = 0.1
160 config.momentum = 0
161 config.batchSize = 128
162 config.filterSize = 5
163
164 if opt.heavy then
165
166    logString('Using the heavy configuration.\n')
167    config.nbChannels = 16
168    config.nbBlocks = 4
169    config.nbEpochs = 250
170    config.nbEpochsInit = 100
171    config.nbTrainSamples = 32768
172    config.nbValidationSamples = 1024
173    config.nbTestSamples = 1024
174
175 else
176
177    logString('Using the light configuration.\n')
178    config.nbChannels = 2
179    config.nbBlocks = 2
180    config.nbEpochs = 6
181    config.nbEpochsInit = 3
182    config.nbTrainSamples = 1024
183    config.nbValidationSamples = 1024
184    config.nbTestSamples = 1024
185
186 end
187
188 if opt.nbEpochs > 0 then
189    config.nbEpochs = opt.nbEpochs
190 end
191
192 if opt.nbChannels > 0 then
193    config.nbChannels = opt.nbChannels
194 end
195
196 if opt.learningRate > 0 then
197    config.learningRate = opt.learningRate
198 end
199
200 if opt.momentum >= 0 then
201    config.momentum = opt.momentum
202 end
203
204 ----------------------------------------------------------------------
205
206 function tensorCensus(tensorType, model)
207
208    local nb = {}
209
210    local function countThings(m)
211       for k, i in pairs(m) do
212          if torch.type(i) == tensorType then
213             nb[k] = (nb[k] or 0) + i:nElement()
214          end
215       end
216    end
217
218    model:apply(countThings)
219
220    return nb
221
222 end
223
224 ----------------------------------------------------------------------
225
226 function loadData(first, nb, name)
227    logString('Loading data `' .. name .. '\'.\n')
228
229    local persistentFileName = string.format('%s/persistent_%d_%d.dat',
230                                             opt.dataDir,
231                                             first,
232                                             nb)
233
234    -- This is at what framerate we work. It is greater than 1 so that
235    -- we can keep on disk sequences at a higher frame rate for videos
236    -- and explaining materials
237
238    local frameRate = 4
239
240    local data
241
242    if not path.exists(persistentFileName) then
243       logString(string.format('No persistent data structure, creating it (%d samples).\n', nb))
244       local data = {}
245       data.name = name
246       data.nbSamples = nb
247       data.width = 64
248       data.height = 64
249       data.input = mynn.SlowTensor(data.nbSamples, 2, data.height, data.width)
250       data.target = mynn.SlowTensor(data.nbSamples, 1, data.height, data.width)
251
252       for i = 1, data.nbSamples do
253          local n = i-1 + first-1
254          local prefix = string.format('%s/%03d/dyn_%06d',
255                                       opt.dataDir,
256                                       math.floor(n/1000), n)
257
258          function localLoad(filename, tensor)
259             local tmp
260             tmp = image.load(filename)
261             tmp:mul(-1.0):add(1.0)
262             tensor:copy(torch.max(tmp, 1))
263          end
264
265          localLoad(prefix .. '_world_000.png', data.input[i][1])
266          localLoad(prefix .. '_grab.png',    data.input[i][2])
267          localLoad(string.format('%s_world_%03d.png', prefix, frameRate),
268                    data.target[i][1])
269       end
270
271       data.persistentFileName = persistentFileName
272
273       torch.save(persistentFileName, data)
274    end
275
276    logCommand('sha256sum -b ' .. persistentFileName)
277
278    data = torch.load(persistentFileName)
279
280    return data
281 end
282
283 ----------------------------------------------------------------------
284
285 -- This function gets as input a list of tensors of arbitrary
286 -- dimensions each, but whose two last dimension stands for height x
287 -- width. It creates an image tensor (2d, one channel) with each
288 -- argument tensor unfolded per row.
289
290 function imageFromTensors(bt, signed)
291    local gap = 1
292    local tgap = -1
293    local width = 0
294    local height = gap
295
296    for _, t in pairs(bt) do
297       -- print(t:size())
298       local d = t:dim()
299       local h, w = t:size(d - 1), t:size(d)
300       local n = t:nElement() / (w * h)
301       width = math.max(width, gap + n * (gap + w))
302       height = height + gap + tgap + gap + h
303    end
304
305    local e = torch.Tensor(3, height, width):fill(1.0)
306    local y0 = 1 + gap
307
308    for _, t in pairs(bt) do
309       local d = t:dim()
310       local h, w = t:size(d - 1), t:size(d)
311       local n = t:nElement() / (w * h)
312       local z = t:norm() / math.sqrt(t:nElement())
313
314       local x0 = 1 + gap + math.floor( (width - n * (w + gap)) /2 )
315       local u = torch.Tensor(t:size()):copy(t):resize(n, h, w)
316       for m = 1, n do
317
318          for c = 1, 3 do
319             for y = 0, h+1 do
320                e[c][y0 + y - 1][x0     - 1] = 0.0
321                e[c][y0 + y - 1][x0 + w    ] = 0.0
322             end
323             for x = 0, w+1 do
324                e[c][y0     - 1][x0 + x - 1] = 0.0
325                e[c][y0 + h    ][x0 + x - 1] = 0.0
326             end
327          end
328
329          for y = 1, h do
330             for x = 1, w do
331                local v = u[m][y][x] / z
332                local r, g, b
333                if signed then
334                   if v < -1 then
335                      r, g, b = 0.0, 0.0, 1.0
336                   elseif v > 1 then
337                      r, g, b = 1.0, 0.0, 0.0
338                   elseif v >= 0 then
339                      r, g, b = 1.0, 1.0 - v, 1.0 - v
340                   else
341                      r, g, b = 1.0 + v, 1.0 + v, 1.0
342                   end
343                else
344                   if v <= 0 then
345                      r, g, b = 1.0, 1.0, 1.0
346                   elseif v > 1 then
347                      r, g, b = 0.0, 0.0, 0.0
348                   else
349                      r, g, b = 1.0 - v, 1.0 - v, 1.0 - v
350                   end
351                end
352                e[1][y0 + y - 1][x0 + x - 1] = r
353                e[2][y0 + y - 1][x0 + x - 1] = g
354                e[3][y0 + y - 1][x0 + x - 1] = b
355             end
356          end
357          x0 = x0 + w + gap
358       end
359       y0 = y0 + h + gap + tgap + gap
360    end
361
362    return e
363 end
364
365 function collectAllOutputs(model, collection, which)
366    if torch.type(model) == 'nn.Sequential' then
367       for i = 1, #model.modules do
368          collectAllOutputs(model.modules[i], collection, which)
369       end
370    elseif not which or which[torch.type(model)] then
371       local t = torch.type(model.output)
372       if t == 'torch.FloatTensor' or t == 'torch.CudaTensor' then
373          collection.nb = collection.nb + 1
374          collection.outputs[collection.nb] = model.output
375       end
376    end
377 end
378
379 function saveInternalsImage(model, data, n)
380    -- Explicitely copy to keep input as a mynn.FastTensor
381    local input = mynn.FastTensor(1, 2, data.height, data.width)
382    input:copy(data.input:narrow(1, n, 1))
383
384    local output = model:forward(input)
385
386    local collection = {}
387    collection.outputs = {}
388    collection.nb = 1
389    collection.outputs[collection.nb] = input
390
391    local which = {}
392    which['nn.ReLU'] = true
393    collectAllOutputs(model, collection, which)
394
395    if collection.outputs[collection.nb] ~= model.output then
396       collection.nb = collection.nb + 1
397       collection.outputs[collection.nb] = model.output
398    end
399
400    local fileName = string.format('%s/internals_%s_%06d.png',
401                                   opt.resultDir,
402                                   data.name, n)
403
404    logString('Saving ' .. fileName .. '\n')
405    image.save(fileName, imageFromTensors(collection.outputs))
406 end
407
408 ----------------------------------------------------------------------
409
410 function saveResultImage(model, data, prefix, nbMax, highlight)
411    local l2criterion = nn.MSECriterion()
412
413    if useGPU then
414       logString('Moving the criterion to the GPU.\n')
415       l2criterion:cuda()
416    end
417
418    local prefix = prefix or 'result'
419    local result = torch.Tensor(data.height * 4 + 5, data.width + 2)
420    local input = mynn.FastTensor(1, 2, data.height, data.width)
421    local target = mynn.FastTensor(1, 1, data.height, data.width)
422
423    local nbMax = nbMax or 50
424
425    local nb = math.min(nbMax, data.nbSamples)
426
427    model:evaluate()
428
429    logString(string.format('Write %d result images `%s\' for set `%s\' in %s.\n',
430                            nb, prefix, data.name,
431                            opt.resultDir))
432
433    for n = 1, nb do
434
435       -- Explicitely copy to keep input as a mynn.FastTensor
436       input:copy(data.input:narrow(1, n, 1))
437       target:copy(data.target:narrow(1, n, 1))
438
439       local output = model:forward(input)
440
441       local loss = l2criterion:forward(output, target)
442
443       result:fill(1.0)
444
445       if highlight then
446          for i = 1, data.height do
447             for j = 1, data.width do
448                local v = data.input[n][1][i][j]
449                result[1 + i + 0 * (data.height + 1)][1 + j] = data.input[n][2][i][j]
450                result[1 + i + 1 * (data.height + 1)][1 + j] = v
451                local a = data.target[n][1][i][j]
452                local b = output[1][1][i][j]
453                result[1 + i + 2 * (data.height + 1)][1 + j] =
454                   a * math.min(1, 0.1 + 2.0 * math.abs(a - v))
455                result[1 + i + 3 * (data.height + 1)][1 + j] =
456                   b * math.min(1, 0.1 + 2.0 * math.abs(b - v))
457             end
458          end
459       else
460          for i = 1, data.height do
461             for j = 1, data.width do
462                result[1 + i + 0 * (data.height + 1)][1 + j] = data.input[n][2][i][j]
463                result[1 + i + 1 * (data.height + 1)][1 + j] = data.input[n][1][i][j]
464                result[1 + i + 2 * (data.height + 1)][1 + j] = data.target[n][1][i][j]
465                result[1 + i + 3 * (data.height + 1)][1 + j] = output[1][1][i][j]
466             end
467          end
468       end
469
470       result:mul(-1.0):add(1.0)
471
472       local fileName = string.format('%s/%s_%s_%06d.png',
473                                      opt.resultDir,
474                                      prefix,
475                                      data.name, n)
476
477       logString(string.format('LOSS_ON_SAMPLE %f %s\n', loss, fileName))
478
479       image.save(fileName, result)
480    end
481 end
482
483 ----------------------------------------------------------------------
484
485 function createTower(filterSize, nbChannels, nbBlocks)
486    local tower = mynn.Sequential()
487
488    for b = 1, nbBlocks do
489       local block = mynn.Sequential()
490
491       block:add(mynn.SpatialConvolution(nbChannels,
492                                         nbChannels,
493                                         filterSize, filterSize,
494                                         1, 1,
495                                         (filterSize - 1) / 2, (filterSize - 1) / 2))
496       block:add(mynn.SpatialBatchNormalization(nbChannels))
497       block:add(mynn.ReLU(true))
498
499       block:add(mynn.SpatialConvolution(nbChannels,
500                                         nbChannels,
501                                         filterSize, filterSize,
502                                         1, 1,
503                                         (filterSize - 1) / 2, (filterSize - 1) / 2))
504
505       local parallel = mynn.ConcatTable()
506       parallel:add(block):add(mynn.Identity())
507
508       tower:add(parallel):add(mynn.CAddTable(true))
509
510       tower:add(mynn.SpatialBatchNormalization(nbChannels))
511       tower:add(mynn.ReLU(true))
512    end
513
514    return tower
515 end
516
517 function createModel(filterSize, nbChannels, nbBlocks)
518    local model = mynn.Sequential()
519
520    model:add(mynn.SpatialConvolution(2,
521                                      nbChannels,
522                                      filterSize, filterSize,
523                                      1, 1,
524                                      (filterSize - 1) / 2, (filterSize - 1) / 2))
525
526    model:add(mynn.SpatialBatchNormalization(nbChannels))
527    model:add(mynn.ReLU(true))
528
529    local towerCode   = createTower(filterSize, nbChannels, nbBlocks)
530    local towerDecode = createTower(filterSize, nbChannels, nbBlocks)
531
532    model:add(towerCode)
533    model:add(towerDecode)
534
535    -- Decode to a single channel, which is the final image
536    model:add(mynn.SpatialConvolution(nbChannels,
537                                      1,
538                                      filterSize, filterSize,
539                                      1, 1,
540                                      (filterSize - 1) / 2, (filterSize - 1) / 2))
541
542    return model
543 end
544
545 ----------------------------------------------------------------------
546
547 function fillBatch(data, first, nb, batch, permutation)
548    for k = 1, nb do
549       local i
550       if permutation then
551          i = permutation[first + k - 1]
552       else
553          i = first + k - 1
554       end
555       batch.input[k] = data.input[i]
556       batch.target[k] = data.target[i]
557    end
558 end
559
560 function trainModel(model,
561                     trainData, validationData, nbEpochs, learningRate,
562                     learningStateFile)
563
564    local l2criterion = nn.MSECriterion()
565    local batchSize = config.batchSize
566
567    if useGPU then
568       logString('Moving the criterion to the GPU.\n')
569       l2criterion:cuda()
570    end
571
572    local batch = {}
573    batch.input = mynn.FastTensor(batchSize, 2, trainData.height, trainData.width)
574    batch.target = mynn.FastTensor(batchSize, 1, trainData.height, trainData.width)
575
576    local startingEpoch = 1
577
578    if model.epoch then
579       startingEpoch = model.epoch + 1
580    end
581
582    if model.RNGState then
583       torch.setRNGState(model.RNGState)
584    end
585
586    logString('Starting training.\n')
587
588    local parameters, gradParameters = model:getParameters()
589    logString(string.format('model has %d parameters.\n', parameters:storage():size(1)))
590
591    local averageTrainLoss, averageValidationLoss
592    local trainTime, validationTime
593
594    local sgdState = {
595       learningRate = config.learningRate,
596       momentum = config.momentum,
597       learningRateDecay = 0
598    }
599
600    for e = startingEpoch, nbEpochs do
601
602       model:training()
603
604       local permutation = torch.randperm(trainData.nbSamples)
605
606       local accLoss = 0.0
607       local nbBatches = 0
608       local startTime = sys.clock()
609
610       for b = 1, trainData.nbSamples, batchSize do
611
612          fillBatch(trainData, b, batchSize, batch, permutation)
613
614          local opfunc = function(x)
615             -- Surprisingly copy() needs this check
616             if x ~= parameters then
617                parameters:copy(x)
618             end
619
620             local output = model:forward(batch.input)
621             local loss = l2criterion:forward(output, batch.target)
622
623             local dLossdOutput = l2criterion:backward(output, batch.target)
624             gradParameters:zero()
625             model:backward(batch.input, dLossdOutput)
626
627             accLoss = accLoss + loss
628             nbBatches = nbBatches + 1
629
630             return loss, gradParameters
631          end
632
633          optim.sgd(opfunc, parameters, sgdState)
634
635       end
636
637       trainTime = sys.clock() - startTime
638       averageTrainLoss = accLoss / nbBatches
639
640       ----------------------------------------------------------------------
641       -- Validation losses
642       do
643          model:evaluate()
644
645          local accLoss = 0.0
646          local nbBatches = 0
647          local startTime = sys.clock()
648
649          for b = 1, validationData.nbSamples, batchSize do
650             fillBatch(validationData, b, batchSize, batch)
651             local output = model:forward(batch.input)
652             accLoss = accLoss + l2criterion:forward(output, batch.target)
653             nbBatches = nbBatches + 1
654          end
655
656          validationTime = sys.clock() - startTime
657          averageValidationLoss = accLoss / nbBatches;
658       end
659
660       logString(string.format('Epoch train %0.2fs (%0.2fms / sample), validation %0.2fs (%0.2fms / sample).\n',
661                               trainTime,
662                               1000 * trainTime / trainData.nbSamples,
663                               validationTime,
664                               1000 * validationTime / validationData.nbSamples))
665
666       logString(string.format('LOSS %d %f %f\n', e, averageTrainLoss, averageValidationLoss),
667                 colors.green)
668
669       ----------------------------------------------------------------------
670       -- Save a persistent state so that we can restart from there
671
672       if learningStateFile then
673          model.RNGState = torch.getRNGState()
674          model.epoch = e
675          model:clearState()
676          logString('Writing ' .. learningStateFile .. '.\n')
677          torch.save(learningStateFile, model)
678       end
679
680       ----------------------------------------------------------------------
681       -- Save a duplicate of the persistent state from time to time
682
683       if opt.resultFreq > 0 and e%opt.resultFreq == 0 then
684          torch.save(string.format('%s/epoch_%05d_model', opt.resultDir, e), model)
685          saveResultImage(model, trainData)
686          saveResultImage(model, validationData)
687       end
688
689    end
690
691 end
692
693 function createAndTrainModel(trainData, validationData)
694
695    local model
696
697    local learningStateFile = opt.learningStateFile
698
699    if learningStateFile == '' then
700       learningStateFile = opt.resultDir .. '/learning.state'
701    end
702
703    local gotlearningStateFile
704
705    logString('Using the learning state file ' .. learningStateFile .. '\n')
706
707    if pcall(function () model = torch.load(learningStateFile) end) then
708
709       gotlearningStateFile = true
710
711    else
712
713       model = createModel(config.filterSize, config.nbChannels, config.nbBlocks)
714
715       if useGPU then
716          logString('Moving the model to the GPU.\n')
717          model:cuda()
718       end
719
720    end
721
722    logString(tostring(model) .. '\n')
723
724    if gotlearningStateFile then
725       logString(string.format('Found a learning state with %d epochs finished.\n', model.epoch),
726                 colors.red)
727    end
728
729    if opt.exampleInternals > 0 then
730       saveInternalsImage(model, validationData, opt.exampleInternals)
731       os.exit(0)
732    end
733
734    trainModel(model,
735               trainData, validationData,
736               config.nbEpochs, config.learningRate,
737               learningStateFile)
738
739    return model
740
741 end
742
743 for i, j in pairs(config) do
744    logString('config ' .. i .. ' = \'' .. j ..'\'\n')
745 end
746
747 local trainData = loadData(1, config.nbTrainSamples, 'train')
748 local validationData = loadData(config.nbTrainSamples + 1, config.nbValidationSamples, 'validation')
749 local testData = loadData(config.nbTrainSamples + config.nbValidationSamples + 1, config.nbTestSamples, 'test')
750
751 local model = createAndTrainModel(trainData, validationData)
752
753 saveResultImage(model, trainData)
754 saveResultImage(model, validationData)
755 saveResultImage(model, testData, nil, testData.nbSamples)