5 dyncnn is a deep-learning algorithm for the prediction of
6 interacting object dynamics
8 Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
9 Written by Francois Fleuret <francois.fleuret@idiap.ch>
11 This file is part of dyncnn.
13 dyncnn is free software: you can redistribute it and/or modify it
14 under the terms of the GNU General Public License version 3 as
15 published by the Free Software Foundation.
17 dyncnn is distributed in the hope that it will be useful, but
18 WITHOUT ANY WARRANTY; without even the implied warranty of
19 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20 General Public License for more details.
22 You should have received a copy of the GNU General Public License
23 along with dyncnn. If not, see <http://www.gnu.org/licenses/>.
-- Command-line option declarations in the torch `--name (default v) help`
-- convention. NOTE(review): these lines most likely sit inside a single
-- long-string handed to an option parser (e.g. torch.CmdLine/lapp) in the
-- full file; that surrounding code was dropped by this extraction, so the
-- review comments are placed before the span rather than interleaved.
-- TODO: recover the complete file before making behavioral edits.
33 ----------------------------------------------------------------------
36 --seed (default 1) random seed
38 --learningStateFile (default '')
39 --dataDir (default './data/10p-mg/')
40 --resultDir (default '/tmp/dyncnn')
42 --learningRate (default -1)
43 --momentum (default -1)
44 --nbEpochs (default -1) nb of epochs for the heavy setting
46 --heavy use the heavy configuration
47 --nbChannels (default -1) nb of channels in the internal layers
48 --resultFreq (default 100)
50 --noLog suppress logging
52 --exampleInternals (default -1)
55 ----------------------------------------------------------------------
-- Rebuilds the invocation string by appending each CLI argument, quoted,
-- to the global `commandLine` (later echoed to the log).
-- NOTE(review): the enclosing `for i = 1, #arg` loop header and the
-- initialization of `commandLine` are missing from this extraction —
-- TODO confirm against the full file.
59 commandLine = commandLine .. ' \'' .. arg[i] .. '\''
62 ----------------------------------------------------------------------
-- Appends `s` to the global log file (if one is open) and flushes it so
-- the log survives a crash. `c` is an optional terminal color, defaulting
-- to `colors.black`. NOTE(review): the function tail (presumably the
-- colored `io.write` of `s` and the closing `end`s) is missing from this
-- extraction — TODO confirm.
68 function logString(s, c)
69 if global.logFile then
70 global.logFile:write(s)
71 global.logFile:flush()
-- Shadows parameter `c` with a defaulted local; `colors` is defined
-- elsewhere in the file (not visible here).
73 local c = c or colors.black
-- Runs shell command `c` via sys.execute and logs both the command and
-- its captured output in blue. NOTE(review): `c` is interpolated into a
-- shell unescaped — callers must only pass trusted strings. Closing `end`
-- is missing from this extraction.
78 function logCommand(c)
79 logString('[' .. c .. '] -> [' .. sys.execute(c) .. ']\n', colors.blue)
-- Top-level logging setup: echo the reconstructed command line, create
-- the result directory, and open (append mode) the persistent log file.
-- NOTE(review): `global` table creation and the `opt.noLog` guard that
-- presumably wraps the io.open are not visible in this extraction.
82 logString('commandline: ' .. commandLine .. '\n', colors.blue)
84 logCommand('mkdir -v -p ' .. opt.resultDir)
87 global.logName = opt.resultDir .. '/log'
88 global.logFile = io.open(global.logName, 'a')
91 ----------------------------------------------------------------------
-- Log-once machinery: `alreadyLoggedString` memoizes by the caller's
-- source line so a given call site logs at most once (in red).
-- NOTE(review): the enclosing function header (a log-once wrapper taking
-- `s`) and its closing `end`s are missing from this extraction — the
-- `debug.getinfo(1).currentline` / `s` references below imply it.
93 alreadyLoggedString = {}
96 local l = debug.getinfo(1).currentline
97 if not alreadyLoggedString[l] then
98 logString('@line ' .. l .. ' ' .. s, colors.red)
99 alreadyLoggedString[l] = s
103 ----------------------------------------------------------------------
-- Runtime environment: thread count and GPU use are driven by env vars.
-- `nbThreads` may be a string here (os.getenv result); torch.setnumthreads
-- accepts it — TODO confirm intended coercion.
105 nbThreads = os.getenv('TORCH_NB_THREADS') or 1
107 useGPU = os.getenv('TORCH_USE_GPU') == 'yes'
-- Logs provenance commands (date, git commit). NOTE(review): the rest of
-- the command list, the loop body calling logCommand(c), and the closing
-- `end` are missing from this extraction.
109 for _, c in pairs({ 'date',
111 'git log -1 --format=%H'
117 logString('useGPU is \'' .. tostring(useGPU) .. '\'.\n')
119 logString('nbThreads is \'' .. nbThreads .. '\'.\n')
121 ----------------------------------------------------------------------
-- Deterministic single-precision CPU defaults, seeded from --seed.
123 torch.setnumthreads(nbThreads)
124 torch.setdefaulttensortype('torch.FloatTensor')
125 torch.manualSeed(opt.seed)
-- `mynn` is a thin indirection layer over `nn`: unknown keys fall through
-- to nn via the metatable, while a few names (tensor constructors, conv
-- layers) are overridden for the CPU/GPU split.
129 -- By default, mynn returns the entries from nn
-- NOTE(review): the __index body (presumably `return nn[key]`) and the
-- creation of `mynn`/`mt` tables are missing from this extraction.
131 function mt.__index(table, key)
134 setmetatable(mynn, mt)
136 -- These are the tensors that can be kept on the CPU
137 mynn.SlowTensor = torch.Tensor
138 -- These are the tensors that should be moved to the GPU
139 mynn.FastTensor = torch.Tensor
141 ----------------------------------------------------------------------
-- GPU overrides: route fast tensors to CUDA and convolutions to cuDNN.
-- NOTE(review): presumably guarded by `if useGPU then` with the cutorch/
-- cudnn requires — that guard is missing from this extraction; confirm
-- before editing.
147 mynn.FastTensor = torch.CudaTensor
148 mynn.SpatialConvolution = cudnn.SpatialConvolution
151 ----------------------------------------------------------------------
-- Experiment configuration. Base hyper-parameters first, then a heavy
-- (--heavy) vs light preset, then explicit CLI overrides (sentinel -1 /
-- negative defaults mean "not set on the command line").
-- NOTE(review): `config = {}`, the if/else around the presets, several
-- preset fields (e.g. light nbEpochs, nbBlocks) and closing `end`s are
-- missing from this extraction.
154 config.learningRate = 0.1
156 config.batchSize = 128
157 config.filterSize = 5
-- Heavy preset: full-size training run.
161 logString('Using the heavy configuration.\n')
162 config.nbChannels = 16
164 config.nbEpochs = 250
165 config.nbEpochsInit = 100
166 config.nbTrainSamples = 32768
167 config.nbValidationSamples = 1024
168 config.nbTestSamples = 1024
-- Light preset: quick smoke-test run.
172 logString('Using the light configuration.\n')
173 config.nbChannels = 2
176 config.nbEpochsInit = 3
177 config.nbTrainSamples = 1024
178 config.nbValidationSamples = 1024
179 config.nbTestSamples = 1024
-- CLI overrides win over the preset. Note momentum uses >= 0 so an
-- explicit zero momentum is honored.
183 if opt.nbEpochs > 0 then
184 config.nbEpochs = opt.nbEpochs
187 if opt.nbChannels > 0 then
188 config.nbChannels = opt.nbChannels
191 if opt.learningRate > 0 then
192 config.learningRate = opt.learningRate
195 if opt.momentum >= 0 then
196 config.momentum = opt.momentum
199 ----------------------------------------------------------------------
-- Counts, per field name, the total number of elements of tensors of type
-- `tensorType` found in every module of `model` (via model:apply).
-- NOTE(review): the creation/return of the accumulator `nb` and closing
-- `end`s are missing from this extraction.
201 function tensorCensus(tensorType, model)
205 local function countThings(m)
206 for k, i in pairs(m) do
207 if torch.type(i) == tensorType then
208 nb[k] = (nb[k] or 0) + i:nElement()
-- Visits every module in the model tree, accumulating into nb.
213 model:apply(countThings)
219 ----------------------------------------------------------------------
-- Loads `nb` samples starting at `first` for the data split `name`.
-- Builds a cached "persistent" torch-serialized tensor pack from the PNG
-- frames on disk the first time, then reloads it on subsequent runs.
-- Returns the `data` table (input/target tensors plus metadata).
-- NOTE(review): several lines (data table creation, height/width setup,
-- frameRate definition, `else`/`end`s, return) are missing from this
-- extraction — the structure below is partial.
221 function loadData(first, nb, name)
222 logString('Loading data `' .. name .. '\'.\n')
224 local persistentFileName = string.format('%s/persistent_%d_%d.dat',
229 -- This is at what framerate we work. It is greater than 1 so that
230 -- we can keep on disk sequences at a higher frame rate for videos
231 -- and explaining materials
-- Cache miss: build the tensor pack from individual PNGs.
237 if not path.exists(persistentFileName) then
238 logString(string.format('No persistent data structure, creating it (%d samples).\n', nb))
-- input: 2 channels (world, grab); target: 1 channel (future world).
244 data.input = mynn.SlowTensor(data.nbSamples, 2, data.height, data.width)
245 data.target = mynn.SlowTensor(data.nbSamples, 1, data.height, data.width)
247 for i = 1, data.nbSamples do
-- 0-based global sample index; samples are sharded 1000 per directory.
248 local n = i-1 + first-1
249 local prefix = string.format('%s/%03d/dyn_%06d',
251 math.floor(n/1000), n)
-- Loads a PNG, inverts it (1 - pixel), and collapses channels by max.
-- NOTE(review): declared global and captures `tmp` globally — likely
-- unintentional in the original; confirm against the full file.
253 function localLoad(filename, tensor)
255 tmp = image.load(filename)
256 tmp:mul(-1.0):add(1.0)
257 tensor:copy(torch.max(tmp, 1))
260 localLoad(prefix .. '_world_000.png', data.input[i][1])
261 localLoad(prefix .. '_grab.png', data.input[i][2])
-- Target is the world image frameRate steps ahead.
262 localLoad(string.format('%s_world_%03d.png', prefix, frameRate),
266 data.persistentFileName = persistentFileName
268 torch.save(persistentFileName, data)
-- Log a checksum of the cache for reproducibility audits.
271 logCommand('sha256sum -b ' .. persistentFileName)
-- Cache hit path.
273 data = torch.load(persistentFileName)
278 ----------------------------------------------------------------------
280 -- This function gets as input a list of tensors of arbitrary
281 -- dimensions each, but whose two last dimension stands for height x
282 -- width. It creates an image tensor (2d, one channel) with each
283 -- argument tensor unfolded per row.
-- NOTE(review): many lines (gap/tgap definitions, y0 init, several loop
-- headers, the signed/unsigned branch conditions, return) are missing
-- from this extraction; the annotations below cover only what is visible.
285 function imageFromTensors(bt, signed)
-- First pass: compute the overall canvas size from each tensor's
-- trailing h x w dims and the number n of h x w slices it contains.
291 for _, t in pairs(bt) do
294 local h, w = t:size(d - 1), t:size(d)
295 local n = t:nElement() / (w * h)
296 width = math.max(width, gap + n * (gap + w))
297 height = height + gap + tgap + gap + h
-- White RGB canvas.
300 local e = torch.Tensor(3, height, width):fill(1.0)
-- Second pass: render each tensor's slices on its own row.
303 for _, t in pairs(bt) do
305 local h, w = t:size(d - 1), t:size(d)
306 local n = t:nElement() / (w * h)
-- z: RMS of the tensor, used to normalize pixel intensities.
307 local z = t:norm() / math.sqrt(t:nElement())
309 local x0 = 1 + gap + math.floor( (width - n * (w + gap)) /2 )
310 local u = torch.Tensor(t:size()):copy(t):resize(n, h, w)
-- 1-pixel black frame around each slice (left/right then top/bottom).
315 e[c][y0 + y - 1][x0 - 1] = 0.0
316 e[c][y0 + y - 1][x0 + w ] = 0.0
319 e[c][y0 - 1][x0 + x - 1] = 0.0
320 e[c][y0 + h ][x0 + x - 1] = 0.0
-- Normalized pixel value; mapped to blue/red for signed data, to
-- grayscale for unsigned (branch conditions not visible here).
326 local v = u[m][y][x] / z
330 r, g, b = 0.0, 0.0, 1.0
332 r, g, b = 1.0, 0.0, 0.0
334 r, g, b = 1.0, 1.0 - v, 1.0 - v
336 r, g, b = 1.0 + v, 1.0 + v, 1.0
340 r, g, b = 1.0, 1.0, 1.0
342 r, g, b = 0.0, 0.0, 0.0
344 r, g, b = 1.0 - v, 1.0 - v, 1.0 - v
347 e[1][y0 + y - 1][x0 + x - 1] = r
348 e[2][y0 + y - 1][x0 + x - 1] = g
349 e[3][y0 + y - 1][x0 + x - 1] = b
-- Advance to the next row of the canvas.
354 y0 = y0 + h + gap + tgap + gap
-- Recursively walks `model`, appending each module's `output` tensor to
-- `collection.outputs` (with running count `collection.nb`). `which` is an
-- optional set of module type names to include; nil means include all.
-- Only Float/Cuda tensors are collected. NOTE(review): closing `end`s are
-- missing from this extraction.
360 function collectAllOutputs(model, collection, which)
361 if torch.type(model) == 'nn.Sequential' then
-- Containers are not collected themselves; recurse into children.
362 for i = 1, #model.modules do
363 collectAllOutputs(model.modules[i], collection, which)
365 elseif not which or which[torch.type(model)] then
366 local t = torch.type(model.output)
367 if t == 'torch.FloatTensor' or t == 'torch.CudaTensor' then
368 collection.nb = collection.nb + 1
369 collection.outputs[collection.nb] = model.output
-- Runs sample `n` of `data` through `model` and saves a PNG visualizing
-- the input plus every collected intermediate activation (ReLU outputs)
-- via imageFromTensors. NOTE(review): collection.nb initialization, the
-- `which = {}` line, and the fileName format arguments are missing from
-- this extraction.
374 function saveInternalsImage(model, data, n)
375 -- Explicitely copy to keep input as a mynn.FastTensor
376 local input = mynn.FastTensor(1, 2, data.height, data.width)
377 input:copy(data.input:narrow(1, n, 1))
379 local output = model:forward(input)
381 local collection = {}
382 collection.outputs = {}
-- The raw input is visualized first.
384 collection.outputs[collection.nb] = input
-- Only ReLU activations are collected from the model internals.
387 which['nn.ReLU'] = true
388 collectAllOutputs(model, collection, which)
-- Ensure the final model output appears even if its module type was
-- filtered out above.
390 if collection.outputs[collection.nb] ~= model.output then
391 collection.nb = collection.nb + 1
392 collection.outputs[collection.nb] = model.output
395 local fileName = string.format('%s/internals_%s_%06d.png',
399 logString('Saving ' .. fileName .. '\n')
400 image.save(fileName, imageFromTensors(collection.outputs))
403 ----------------------------------------------------------------------
-- For up to `nbMax` (default 50) samples of `data`, runs the model and
-- writes a 4-row composite PNG per sample (grab / world / target /
-- prediction), logging the per-sample MSE loss. `highlight` selects an
-- alternative rendering that emphasizes pixels differing from the input.
-- NOTE(review): the `if useGPU`/`if highlight` branch headers, the sample
-- loop header, fileName format args, and closing `end`s are missing from
-- this extraction.
405 function saveResultImage(model, data, prefix, nbMax, highlight)
406 local l2criterion = nn.MSECriterion()
409 logString('Moving the criterion to the GPU.\n')
413 local prefix = prefix or 'result'
-- Canvas: 4 stacked height-sized rows plus 1-pixel separators/borders.
414 local result = torch.Tensor(data.height * 4 + 5, data.width + 2)
415 local input = mynn.FastTensor(1, 2, data.height, data.width)
416 local target = mynn.FastTensor(1, 1, data.height, data.width)
418 local nbMax = nbMax or 50
420 local nb = math.min(nbMax, data.nbSamples)
424 logString(string.format('Write %d result images `%s\' for set `%s\' in %s.\n',
425 nb, prefix, data.name,
430 -- Explicitely copy to keep input as a mynn.FastTensor
431 input:copy(data.input:narrow(1, n, 1))
432 target:copy(data.target:narrow(1, n, 1))
434 local output = model:forward(input)
436 local loss = l2criterion:forward(output, target)
-- Highlight rendering: scale target/prediction by how much each pixel
-- changed relative to the input world image.
441 for i = 1, data.height do
442 for j = 1, data.width do
443 local v = data.input[n][1][i][j]
444 result[1 + i + 0 * (data.height + 1)][1 + j] = data.input[n][2][i][j]
445 result[1 + i + 1 * (data.height + 1)][1 + j] = v
446 local a = data.target[n][1][i][j]
447 local b = output[1][1][i][j]
448 result[1 + i + 2 * (data.height + 1)][1 + j] =
449 a * math.min(1, 0.1 + 2.0 * math.abs(a - v))
450 result[1 + i + 3 * (data.height + 1)][1 + j] =
451 b * math.min(1, 0.1 + 2.0 * math.abs(b - v))
-- Plain rendering: rows are grab, world, target, prediction verbatim.
455 for i = 1, data.height do
456 for j = 1, data.width do
457 result[1 + i + 0 * (data.height + 1)][1 + j] = data.input[n][2][i][j]
458 result[1 + i + 1 * (data.height + 1)][1 + j] = data.input[n][1][i][j]
459 result[1 + i + 2 * (data.height + 1)][1 + j] = data.target[n][1][i][j]
460 result[1 + i + 3 * (data.height + 1)][1 + j] = output[1][1][i][j]
-- Invert for display (dataset images were inverted at load time).
465 result:mul(-1.0):add(1.0)
467 local fileName = string.format('%s/%s_%s_%06d.png',
472 logString(string.format('LOSS_ON_SAMPLE %f %s\n', loss, fileName))
474 image.save(fileName, result)
478 ----------------------------------------------------------------------
-- Builds a residual tower of `nbBlocks` blocks. Each block is
-- conv -> BN -> ReLU -> conv, added back to its input through a
-- ConcatTable/CAddTable skip connection, followed by BN and ReLU.
-- Padding (filterSize-1)/2 preserves spatial size (filterSize must be
-- odd). NOTE(review): some constructor arguments (output channel counts,
-- strides), the `return tower`, and closing `end`s are missing from this
-- extraction.
480 function createTower(filterSize, nbChannels, nbBlocks)
481 local tower = mynn.Sequential()
483 for b = 1, nbBlocks do
484 local block = mynn.Sequential()
486 block:add(mynn.SpatialConvolution(nbChannels,
488 filterSize, filterSize,
490 (filterSize - 1) / 2, (filterSize - 1) / 2))
491 block:add(mynn.SpatialBatchNormalization(nbChannels))
492 block:add(mynn.ReLU(true))
494 block:add(mynn.SpatialConvolution(nbChannels,
496 filterSize, filterSize,
498 (filterSize - 1) / 2, (filterSize - 1) / 2))
-- Identity skip connection: output = block(x) + x.
500 local parallel = mynn.ConcatTable()
501 parallel:add(block):add(mynn.Identity())
503 tower:add(parallel):add(mynn.CAddTable(true))
505 tower:add(mynn.SpatialBatchNormalization(nbChannels))
506 tower:add(mynn.ReLU(true))
-- Assembles the full network: a 2->nbChannels input convolution with
-- BN+ReLU, two residual towers (encode, decode), and a final convolution
-- down to a single output channel (the predicted image).
-- NOTE(review): the towerCode:add / model:add(towerCode) wiring, the
-- `return model`, and closing `end` are missing from this extraction.
512 function createModel(filterSize, nbChannels, nbBlocks)
513 local model = mynn.Sequential()
-- Input stem: 2 channels (world + grab) -> nbChannels.
515 model:add(mynn.SpatialConvolution(2,
517 filterSize, filterSize,
519 (filterSize - 1) / 2, (filterSize - 1) / 2))
521 model:add(mynn.SpatialBatchNormalization(nbChannels))
522 model:add(mynn.ReLU(true))
524 local towerCode = createTower(filterSize, nbChannels, nbBlocks)
525 local towerDecode = createTower(filterSize, nbChannels, nbBlocks)
528 model:add(towerDecode)
530 -- Decode to a single channel, which is the final image
531 model:add(mynn.SpatialConvolution(nbChannels,
533 filterSize, filterSize,
535 (filterSize - 1) / 2, (filterSize - 1) / 2))
540 ----------------------------------------------------------------------
-- Copies `nb` samples of `data` starting at `first` into the preallocated
-- `batch` tensors, optionally shuffled through `permutation` (identity
-- when nil — presumably handled by missing lines). NOTE(review): the
-- `for k = 1, nb` loop header, the nil-permutation branch, and closing
-- `end`s are missing from this extraction. `i` below is assigned without
-- `local` — likely a global leak in the original; confirm.
542 function fillBatch(data, first, nb, batch, permutation)
546 i = permutation[first + k - 1]
550 batch.input[k] = data.input[i]
551 batch.target[k] = data.target[i]
-- Main SGD training loop. Minimizes MSE over `trainData` for `nbEpochs`
-- epochs (resuming from model.epoch / model.RNGState when present),
-- evaluates on `validationData` each epoch, and periodically checkpoints
-- the model and result images. NOTE(review): many lines are missing from
-- this extraction (GPU guard bodies, batch table creation, accLoss/
-- nbBatches initializations, epoch bookkeeping, closing `end`s); the
-- annotations cover only the visible structure.
555 function trainModel(model,
556 trainData, validationData, nbEpochs, learningRate,
559 local l2criterion = nn.MSECriterion()
560 local batchSize = config.batchSize
563 logString('Moving the criterion to the GPU.\n')
-- Preallocated device-resident batch buffers, reused every iteration.
568 batch.input = mynn.FastTensor(batchSize, 2, trainData.height, trainData.width)
569 batch.target = mynn.FastTensor(batchSize, 1, trainData.height, trainData.width)
-- Resume support: a reloaded checkpoint carries epoch and RNG state.
571 local startingEpoch = 1
574 startingEpoch = model.epoch + 1
577 if model.RNGState then
578 torch.setRNGState(model.RNGState)
581 logString('Starting training.\n')
-- Flattened parameter view shared by optim.sgd.
583 local parameters, gradParameters = model:getParameters()
584 logString(string.format('model has %d parameters.\n', parameters:storage():size(1)))
586 local averageTrainLoss, averageValidationLoss
587 local trainTime, validationTime
-- SGD hyper-parameters (sgdState table creation not visible here).
590 learningRate = config.learningRate,
591 momentum = config.momentum,
592 learningRateDecay = 0
595 for e = startingEpoch, nbEpochs do
-- Fresh shuffle each epoch.
599 local permutation = torch.randperm(trainData.nbSamples)
603 local startTime = sys.clock()
605 for b = 1, trainData.nbSamples, batchSize do
607 fillBatch(trainData, b, batchSize, batch, permutation)
-- Closure evaluated by optim.sgd: forward, backward, return loss+grad.
609 local opfunc = function(x)
610 -- Surprisingly copy() needs this check
611 if x ~= parameters then
615 local output = model:forward(batch.input)
616 local loss = l2criterion:forward(output, batch.target)
618 local dLossdOutput = l2criterion:backward(output, batch.target)
619 gradParameters:zero()
620 model:backward(batch.input, dLossdOutput)
622 accLoss = accLoss + loss
623 nbBatches = nbBatches + 1
625 return loss, gradParameters
628 optim.sgd(opfunc, parameters, sgdState)
632 trainTime = sys.clock() - startTime
633 averageTrainLoss = accLoss / nbBatches
635 ----------------------------------------------------------------------
-- Validation pass: forward-only, no shuffling.
642 local startTime = sys.clock()
644 for b = 1, validationData.nbSamples, batchSize do
645 fillBatch(validationData, b, batchSize, batch)
646 local output = model:forward(batch.input)
647 accLoss = accLoss + l2criterion:forward(output, batch.target)
648 nbBatches = nbBatches + 1
651 validationTime = sys.clock() - startTime
652 averageValidationLoss = accLoss / nbBatches;
655 logString(string.format('Epoch train %0.2fs (%0.2fms / sample), validation %0.2fs (%0.2fms / sample).\n',
657 1000 * trainTime / trainData.nbSamples,
659 1000 * validationTime / validationData.nbSamples))
-- Machine-parsable per-epoch loss line.
661 logString(string.format('LOSS %d %f %f\n', e, averageTrainLoss, averageValidationLoss),
664 ----------------------------------------------------------------------
665 -- Save a persistent state so that we can restart from there
667 if learningStateFile then
668 model.RNGState = torch.getRNGState()
671 logString('Writing ' .. learningStateFile .. '.\n')
672 torch.save(learningStateFile, model)
675 ----------------------------------------------------------------------
676 -- Save a duplicate of the persistent state from time to time
678 if opt.resultFreq > 0 and e%opt.resultFreq == 0 then
679 torch.save(string.format('%s/epoch_%05d_model', opt.resultDir, e), model)
680 saveResultImage(model, trainData)
681 saveResultImage(model, validationData)
-- Loads a checkpointed model from the learning-state file when one
-- exists (pcall guards against a missing/corrupt file), otherwise builds
-- a fresh model; moves it to the GPU if requested, optionally dumps an
-- internals image, then trains it. NOTE(review): the `local model`
-- declaration, GPU-move body, trainModel call head, `return model`, and
-- closing `end`s are missing from this extraction.
688 function createAndTrainModel(trainData, validationData)
692 local learningStateFile = opt.learningStateFile
-- Default checkpoint location inside the result directory.
694 if learningStateFile == '' then
695 learningStateFile = opt.resultDir .. '/learning.state'
698 local gotlearningStateFile
700 logString('Using the learning state file ' .. learningStateFile .. '\n')
702 if pcall(function () model = torch.load(learningStateFile) end) then
704 gotlearningStateFile = true
708 model = createModel(config.filterSize, config.nbChannels, config.nbBlocks)
711 logString('Moving the model to the GPU.\n')
717 logString(tostring(model) .. '\n')
719 if gotlearningStateFile then
720 logString(string.format('Found a learning state with %d epochs finished.\n', model.epoch),
-- Optional debugging: visualize internal activations for one sample.
724 if opt.exampleInternals > 0 then
725 saveInternalsImage(model, validationData, opt.exampleInternals)
730 trainData, validationData,
731 config.nbEpochs, config.learningRate,
-- Script entry sequence: log the effective configuration, load the three
-- disjoint data splits (train / validation / test partitioned by sample
-- index), train, and write result images for every split (all test
-- samples, default count for train/validation). NOTE(review): the closing
-- `end` of the config-logging loop is missing from this extraction.
738 for i, j in pairs(config) do
739 logString('config ' .. i .. ' = \'' .. j ..'\'\n')
742 local trainData = loadData(1, config.nbTrainSamples, 'train')
743 local validationData = loadData(config.nbTrainSamples + 1, config.nbValidationSamples, 'validation')
744 local testData = loadData(config.nbTrainSamples + config.nbValidationSamples + 1, config.nbTestSamples, 'test')
746 local model = createAndTrainModel(trainData, validationData)
748 saveResultImage(model, trainData)
749 saveResultImage(model, validationData)
750 saveResultImage(model, testData, nil, testData.nbSamples)