-- dyncnn is a deep-learning algorithm for the prediction of
-- interacting object dynamics
--
-- Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
-- Written by Francois Fleuret <francois.fleuret@idiap.ch>
--
-- This file is part of dyncnn.
--
-- dyncnn is free software: you can redistribute it and/or modify it
-- under the terms of the GNU General Public License version 3 as
-- published by the Free Software Foundation.
--
-- dyncnn is distributed in the hope that it will be useful, but
-- WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-- General Public License for more details.
--
-- You should have received a copy of the GNU General Public License
-- along with dyncnn. If not, see <http://www.gnu.org/licenses/>.
34 ----------------------------------------------------------------------
35 -- Command line arguments
-- NOTE(review): every line of this excerpt carries a stray leading
-- numeral (original line numbers fused into the text during extraction)
-- and some original lines are elided; code is kept verbatim, only
-- comments are added.
37 local cmd = torch.CmdLine()
39 cmd:text('General setup')
41 cmd:option('-seed', 1, 'initial random seed')
-- defaultNbThreads / defaultUseGPU are defined outside this excerpt,
-- presumably derived from the environment variables named in the help
-- strings -- confirm against the full source.
42 cmd:option('-nbThreads', defaultNbThreads, 'how many threads (environment variable TORCH_NB_THREADS)')
43 cmd:option('-useGPU', defaultUseGPU, 'should we use cuda (environment variable TORCH_USE_GPU)')
44 cmd:option('-fastGPU', true, 'should we go as fast as possible, possibly non-deterministically')
-- Output / logging options.
49 cmd:option('-resultFreq', 100, 'at which epoch frequency should we save result images')
50 cmd:option('-exampleInternals', '', 'list of comma-separated indices for inner activation images')
51 cmd:option('-noLog', false, 'should we prevent logging')
52 cmd:option('-rundir', '', 'the directory for results')
53 cmd:option('-deltaImages', false, 'should we highlight the difference in result images')
56 cmd:text('Network structure')
-- These three feed createModel/createTower below.
58 cmd:option('-filterSize', 5)
59 cmd:option('-nbChannels', 16)
60 cmd:option('-nbBlocks', 8)
-- Training hyper-parameters and data-set sizes (the three sample counts
-- partition the data sequentially: train, then validation, then test).
65 cmd:option('-nbEpochs', 1000, 'nb of epochs for the heavy setting')
66 cmd:option('-learningRate', 0.1, 'learning rate')
67 cmd:option('-batchSize', 128, 'size of the mini-batches')
68 cmd:option('-nbTrainSamples', 32768)
69 cmd:option('-nbValidationSamples', 1024)
70 cmd:option('-nbTestSamples', 1024)
73 cmd:text('Problem to solve')
75 cmd:option('-dataDir', './data/10p-mg', 'data directory')
-- Prefix every logged line with 'DYNCNN' and a '%F %T' date-time stamp.
77 cmd:addTime('DYNCNN','%F %T')
-- Intentionally global: `params` is read by most functions below.
79 params = cmd:parse(arg)
81 ----------------------------------------------------------------------
88 'git log -1 --format=%H'
94 ----------------------------------------------------------------------
-- loadData(first, nb, name)
--
-- Load a data set of frames from the PNG files under the data
-- directory. `first` is the 1-based index of the first sample; `nb` is
-- presumably the number of samples to load (the lines initializing the
-- `data` table -- data.nbSamples, data.height, data.width -- and the
-- final `return data` are elided from this excerpt); `name` is used
-- for logging here and presumably stored as data.name (read later by
-- saveResultImage).
96 function loadData(first, nb, name)
97 print('Loading data `' .. name .. '\'.')
-- Two input channels and one target channel per sample. SlowTensor is
-- presumably the CPU-side tensor type (FastTensor being the GPU-side
-- one) -- its definition is not visible in this excerpt.
106 data.input = ffnn.SlowTensor(data.nbSamples, 2, data.height, data.width)
107 data.target = ffnn.SlowTensor(data.nbSamples, 1, data.height, data.width)
109 for i = 1, data.nbSamples do
-- 0-based global frame index; files are grouped 1000 per sub-directory.
110 local n = i-1 + first-1
-- The %s format argument (presumably params.dataDir) is on an elided
-- line between these two.
111 local frame = image.load(string.format('%s/%03d/dyn_%06d.png',
113 math.floor(n/1000), n))
-- Invert intensities (x -> 1 - x) so drawn pixels are high and the
-- background is zero,
115 frame:mul(-1.0):add(1.0)
-- then collapse the color channels by taking their max -> a 2D tensor.
116 frame = frame:max(1):select(1, 1)
-- Each frame is a 2x2 grid of (height x width) panes, per the sub()
-- row/column ranges below:
--   input[i][1]  = top-right pane
--   input[i][2]  = top-left pane
--   target[i][1] = bottom-right pane
118 data.input[i][1]:copy(frame:sub(0 * data.height + 1, 1 * data.height,
119 1 * data.width + 1, 2 * data.width))
121 data.input[i][2]:copy(frame:sub(0 * data.height + 1, 1 * data.height,
122 0 * data.width + 1, 1 * data.width))
124 data.target[i][1]:copy(frame:sub(1 * data.height + 1, 2 * data.height,
125 1 * data.width + 1, 2 * data.width))
131 ----------------------------------------------------------------------
-- collectAllOutputs(model, collection, which)
--
-- Recursively walk `model` and append module output tensors to
-- collection.outputs, counted by collection.nb (the caller initializes
-- both). An nn.Sequential is recursed into, child by child; any other
-- module contributes its `output` tensor when `which` is nil (collect
-- everything) or `which` -- a set keyed by torch.type strings -- marks
-- that module type as wanted. The closing `end`s of this function are
-- elided from this excerpt.
133 function collectAllOutputs(model, collection, which)
134 if torch.type(model) == 'nn.Sequential' then
135 for i = 1, #model.modules do
136 collectAllOutputs(model.modules[i], collection, which)
138 elseif not which or which[torch.type(model)] then
-- Only keep actual tensors (a module's `output` may not be one).
139 if torch.isTensor(model.output) then
140 collection.nb = collection.nb + 1
141 collection.outputs[collection.nb] = model.output
-- saveInternalsImage(model, data, n)
--
-- Run sample `n` of `data` through `model` and save a montage of the
-- internal activations (the input, the outputs of the ReLU layers, and
-- the final model output) as a PNG. Several lines are elided from this
-- excerpt: the initialization of collection.nb, part of the `which`
-- table passed to collectAllOutputs (presumably including a plain
-- 'nn.ReLU' entry), and the remaining string.format arguments for
-- fileName.
146 function saveInternalsImage(model, data, n)
147 -- Explicitely copy to keep input as a ffnn.FastTensor
148 local input = ffnn.FastTensor(1, 2, data.height, data.width)
149 input:copy(data.input:narrow(1, n, 1))
151 local output = model:forward(input)
153 local collection = {}
154 collection.outputs = {}
-- The network input itself is the first image of the montage.
156 collection.outputs[collection.nb] = input
-- Collect only ReLU outputs (GPU variants listed below; the elided
-- head of this table presumably covers the CPU variant).
158 collectAllOutputs(model, collection,
161 ['cunn.ReLU'] = true,
162 ['cudnn.ReLU'] = true,
-- Make sure the final model output is included even when the last
-- collected module output is not it (e.g. the last layer is not a
-- ReLU).
166 if collection.outputs[collection.nb] ~= model.output then
167 collection.nb = collection.nb + 1
168 collection.outputs[collection.nb] = model.output
171 local fileName = string.format('%s/internals_%s_%06d.png',
175 print('Saving ' .. fileName)
-- imageFromTensors is defined outside this excerpt (presumably in the
-- img.lua helper mentioned in saveResultImage).
176 image.save(fileName, imageFromTensors(collection.outputs))
179 ----------------------------------------------------------------------
-- highlightImage(a, b)
--
-- When params.deltaImages is set, return `a` modulated by a mask built
-- from |a - b|, so regions where the two images differ are emphasized.
-- The non-delta branch (presumably `return a`) and the closing `end`s
-- are elided from this excerpt.
181 function highlightImage(a, b)
182 if params.deltaImages then
183 local h = torch.csub(a, b):abs()
-- NOTE(review): div(1/h:max()) MULTIPLIES h by its max, so the
-- following mul(0.9):add(0.1) does not land in [0.1, 1.0] as the
-- shape of the expression suggests; plain normalization would be
-- h:div(h:max()). Confirm intent against the full source before
-- changing.
184 h:div(1/h:max()):mul(0.9):add(0.1)
185 return torch.cmul(a, h)
-- saveResultImage(model, data, nbMax)
--
-- Write per-sample result images (inputs, target, prediction) for the
-- first min(nbMax, data.nbSamples) samples of `data` into
-- params.rundir, plus a result_<name>_losses.dat file mapping each
-- image to its MSE loss. nbMax defaults to 50. Several lines are
-- elided from this excerpt: the GPU move of the criterion, the
-- `for n = 1, nb do` loop header, the head of the `comp` table, and
-- (presumably) lossFile:close().
191 function saveResultImage(model, data, nbMax)
192 local criterion = nn.MSECriterion()
194 if params.useGPU then
195 print('Moving the criterion to the GPU.')
-- Reusable single-sample buffers on the fast (GPU-capable) side.
199 local input = ffnn.FastTensor(1, 2, data.height, data.width)
200 local target = ffnn.FastTensor(1, 1, data.height, data.width)
202 local nbMax = nbMax or 50
204 local nb = math.min(nbMax, data.nbSamples)
208 printf('Write %d result images for `%s\'.', nb, data.name)
210 local lossFile = io.open(params.rundir .. '/result_' .. data.name .. '_losses.dat', 'w')
214 -- Explicitely copy to keep input as a ffnn.FastTensor
215 input:copy(data.input:narrow(1, n, 1))
216 target:copy(data.target:narrow(1, n, 1))
218 local output = model:forward(input)
219 local loss = criterion:forward(output, target)
-- Bring the prediction back to the slow (CPU) side for image I/O.
221 output = ffnn.SlowTensor(output:size()):copy(output)
223 -- We use our magical img.lua to create the result images
-- Panes: both input channels, then the target and the prediction,
-- each highlighted against input channel 1 (a no-op unless
-- -deltaImages is set).
230 { pad = 1, data.input[n][1] },
231 { pad = 1, data.input[n][2] },
232 { pad = 1, highlightImage(data.target[n][1], data.input[n][1]) },
233 { pad = 1, highlightImage(output[1][1], data.input[n][1]) },
-- combineImages is defined outside this excerpt.
237 local result = combineImages(1.0, comp)
-- Undo the inversion done at load time (x -> 1 - x) so the saved
-- image is dark-on-light again.
239 result:mul(-1.0):add(1.0)
241 local fileName = string.format('result_%s_%06d.png', data.name, n)
242 image.save(params.rundir .. '/' .. fileName, result)
243 lossFile:write(string.format('%f %s\n', loss, fileName))
247 ----------------------------------------------------------------------
-- createTower(filterSize, nbChannels, nbBlocks)
--
-- Build the residual "tower": nbBlocks resnet-style blocks, each one a
-- conv -> BN -> ReLU -> conv branch summed with an identity shortcut
-- (ConcatTable + CAddTable), followed by BN -> ReLU. With nbBlocks == 0
-- the tower degenerates to nn.Identity(). The declaration of the local
-- `tower`, the if/else glue, some convolution arguments, and the final
-- `return tower` are elided from this excerpt.
249 function createTower(filterSize, nbChannels, nbBlocks)
253 if nbBlocks == 0 then
255 tower = nn.Identity()
259 tower = ffnn.Sequential()
261 for b = 1, nbBlocks do
262 local block = ffnn.Sequential()
-- Padding (filterSize - 1) / 2 preserves the spatial size -- this
-- assumes an odd filterSize (default 5).
264 block:add(ffnn.SpatialConvolution(nbChannels,
266 filterSize, filterSize,
268 (filterSize - 1) / 2, (filterSize - 1) / 2))
269 block:add(ffnn.SpatialBatchNormalization(nbChannels))
-- ReLU(true) operates in place to save memory.
270 block:add(ffnn.ReLU(true))
272 block:add(ffnn.SpatialConvolution(nbChannels,
274 filterSize, filterSize,
276 (filterSize - 1) / 2, (filterSize - 1) / 2))
-- Residual connection: run `block` and an identity in parallel and
-- sum the two results (CAddTable(true) adds in place).
278 local parallel = ffnn.ConcatTable()
279 parallel:add(block):add(ffnn.Identity())
281 tower:add(parallel):add(ffnn.CAddTable(true))
283 tower:add(ffnn.SpatialBatchNormalization(nbChannels))
284 tower:add(ffnn.ReLU(true))
-- createModel(imageWidth, imageHeight, filterSize, nbChannels, nbBlocks)
--
-- Build the full prediction network: a 2-channel -> nbChannels encoder
-- convolution with BN + ReLU, the residual tower, and an
-- nbChannels -> 1 decoder convolution producing the predicted image.
-- Some convolution arguments and the final `return model` are elided
-- from this excerpt.
-- NOTE(review): imageWidth/imageHeight are not used in the visible
-- lines -- confirm against the full source.
292 function createModel(imageWidth, imageHeight,
293 filterSize, nbChannels, nbBlocks)
295 local model = ffnn.Sequential()
297 -- Encode the two input channels (grasping image and starting
298 -- configuration) into the internal number of channels
299 model:add(ffnn.SpatialConvolution(2,
301 filterSize, filterSize,
-- Same-size padding; assumes odd filterSize.
303 (filterSize - 1) / 2, (filterSize - 1) / 2))
305 model:add(ffnn.SpatialBatchNormalization(nbChannels))
306 model:add(ffnn.ReLU(true))
308 -- Add the resnet modules
309 model:add(createTower(filterSize, nbChannels, nbBlocks))
311 -- Decode down to a single channel, which is the final image
312 model:add(ffnn.SpatialConvolution(nbChannels,
314 filterSize, filterSize,
316 (filterSize - 1) / 2, (filterSize - 1) / 2))
321 ----------------------------------------------------------------------
-- trainModel(model, trainSet, validationSet)
--
-- SGD training loop with MSE loss. Supports restarting from a loaded
-- checkpoint (model.epoch / model.RNGState), logs per-epoch train and
-- validation losses with timings, saves model_last.t7 every epoch and
-- a numbered duplicate plus result images every params.resultFreq
-- epochs. Many lines are elided from this excerpt: the batch buffer
-- allocation, accLoss/nbBatches initialization, the GPU moves, the
-- sgdState table head, the epoch bookkeeping (presumably
-- model.epoch = e), and the closing `end`s.
323 function trainModel(model, trainSet, validationSet)
325 local criterion = nn.MSECriterion()
326 local batchSize = params.batchSize
-- Resume after the last completed epoch when the loaded model carries
-- one (the guard around this assignment is elided).
328 local startingEpoch = 1
331 startingEpoch = model.epoch + 1
-- Restore the RNG state saved with the checkpoint so the permutation
-- sequence continues deterministically across restarts.
334 if model.RNGState then
335 printfc(colors.red, 'Using the RNG state from the loaded model.')
336 torch.setRNGState(model.RNGState)
339 if params.useGPU then
340 print('Moving the model and criterion to the GPU.')
345 print('Starting training.')
-- Flatten all weights/gradients into two big vectors for optim.sgd.
347 local parameters, gradParameters = model:getParameters()
348 printf('The model has %d parameters.', parameters:storage():size(1))
350 local averageTrainLoss, averageValidationLoss
351 local trainTime, validationTime
353 ----------------------------------------------------------------------
-- sgdState fields (table constructor head elided).
356 learningRate = params.learningRate,
358 learningRateDecay = 0
363 for e = startingEpoch, params.nbEpochs do
-- Fresh shuffle of the training samples each epoch.
367 local permutation = torch.randperm(trainSet.nbSamples)
371 local startTime = sys.clock()
373 for b = 1, trainSet.nbSamples, batchSize do
-- fillBatch and the `batch` buffers are defined outside this excerpt.
375 fillBatch(trainSet, b, batch, permutation)
-- optim.sgd closure: given parameter vector x, return (loss,
-- gradient). The elided body of the guard presumably copies x into
-- `parameters`.
377 local opfunc = function(x)
378 -- Surprisingly, copy() needs this check
379 if x ~= parameters then
383 local output = model:forward(batch.input)
385 local loss = criterion:forward(output, batch.target)
386 local dLossdOutput = criterion:backward(output, batch.target)
388 gradParameters:zero()
389 model:backward(batch.input, dLossdOutput)
-- accLoss/nbBatches accumulate the per-epoch average (their
-- initialization is elided).
391 accLoss = accLoss + loss
392 nbBatches = nbBatches + 1
394 return loss, gradParameters
397 optim.sgd(opfunc, parameters, sgdState)
401 trainTime = sys.clock() - startTime
402 averageTrainLoss = accLoss / nbBatches
404 ----------------------------------------------------------------------
-- Validation pass: forward only, no permutation, no backward.
412 local startTime = sys.clock()
414 for b = 1, validationSet.nbSamples, batchSize do
415 fillBatch(validationSet, b, batch)
416 local output = model:forward(batch.input)
417 accLoss = accLoss + criterion:forward(output, batch.target)
418 nbBatches = nbBatches + 1
421 validationTime = sys.clock() - startTime
422 averageValidationLoss = accLoss / nbBatches;
425 ----------------------------------------------------------------------
-- Per-epoch log line (some of the format arguments are elided).
427 printfc(colors.green,
429 'epoch %d acc_train_loss %f validation_loss %f [train %.02fs total %.02fms / sample, validation %.02fs total %.02fms / sample]',
435 averageValidationLoss,
438 1000 * trainTime / trainSet.nbSamples,
441 1000 * validationTime / validationSet.nbSamples
444 ----------------------------------------------------------------------
445 -- Save a persistent state so that we can restart from there
448 model.RNGState = torch.getRNGState()
450 torch.save(params.rundir .. '/model_last.t7', model)
452 ----------------------------------------------------------------------
453 -- Save a duplicate of the persistent state from time to time
455 if params.resultFreq > 0 and e%params.resultFreq == 0 then
456 torch.save(string.format('%s/model_%04d.t7', params.rundir, e), model)
457 saveResultImage(model, trainSet)
458 saveResultImage(model, validationSet)
465 ----------------------------------------------------------------------
-- Main script: load the data, resume or create the model, train, then
-- write result images for all three splits. The data is partitioned
-- sequentially: samples 1..nbTrainSamples for training, the next
-- nbValidationSamples for validation, the next nbTestSamples for test.
-- Several lines are elided from this excerpt (the printf argument for
-- the resume message, the else glue, the last createModel argument,
-- the GPU move of the model).
468 local trainSet = loadData(1,
469 params.nbTrainSamples, 'train')
471 local validationSet = loadData(params.nbTrainSamples + 1,
472 params.nbValidationSamples, 'validation')
-- pcall makes a missing/corrupt checkpoint a normal "start from
-- scratch" path instead of an error. `model` is intentionally global.
476 if pcall(function () model = torch.load(params.rundir .. '/model_last.t7') end) then
479 'Found a model with %d epochs completed, starting from there.',
-- On resume only: optionally dump internal-activation images for the
-- comma-separated validation sample indices given on the command line.
482 if params.exampleInternals ~= '' then
483 for _, i in ipairs(string.split(params.exampleInternals, ',')) do
484 saveInternalsImage(model, validationSet, tonumber(i))
-- No checkpoint found: build a fresh model (the nbBlocks argument is
-- on an elided line).
491 model = createModel(trainSet.width, trainSet.height,
492 params.filterSize, params.nbChannels,
497 trainModel(model, trainSet, validationSet)
499 ----------------------------------------------------------------------
-- The test set is only loaded after training, then result images are
-- written for all three splits (up to 1024 for the test set, the
-- default 50 for the other two).
502 local testSet = loadData(params.nbTrainSamples + params.nbValidationSamples + 1,
503 params.nbTestSamples, 'test')
505 if params.useGPU then
506 print('Moving the model and criterion to the GPU.')
510 saveResultImage(model, trainSet)
511 saveResultImage(model, validationSet)
512 saveResultImage(model, testSet, 1024)