5 dyncnn is a deep-learning algorithm for the prediction of
6 interacting object dynamics
8 Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
9 Written by Francois Fleuret <francois.fleuret@idiap.ch>
11 This file is part of dyncnn.
13 dyncnn is free software: you can redistribute it and/or modify it
14 under the terms of the GNU General Public License version 3 as
15 published by the Free Software Foundation.
17 dyncnn is distributed in the hope that it will be useful, but
18 WITHOUT ANY WARRANTY; without even the implied warranty of
19 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20 General Public License for more details.
22 You should have received a copy of the GNU General Public License
23 along with dyncnn. If not, see <http://www.gnu.org/licenses/>.
34 ----------------------------------------------------------------------
--- Print one formatted line on the standard output.
-- Thin wrapper around string.format(). The variadic arguments are
-- forwarded directly instead of being packed into a temporary table and
-- unpacked again: that round trip could drop arguments after a nil and
-- relied on the global unpack(), which only exists in Lua 5.1 / LuaJIT.
-- @param f   format string, as accepted by string.format()
-- @param ... values substituted into f
function printf(f, ...)
   print(string.format(f, ...))
end
--- Print one formatted line wrapped in a color escape sequence.
-- The formatted text is prefixed with c and followed by colors.black,
-- which switches the terminal back to the default color afterwards.
-- The variadic arguments are forwarded directly to string.format()
-- rather than through unpack({...}), which could drop arguments after a
-- nil and needs the 5.1-only global unpack().
-- @param c   color prefix string (e.g. colors.red)
-- @param f   format string, as accepted by string.format()
-- @param ... values substituted into f
function printfc(c, f, ...)
   print(c .. string.format(f, ...) .. colors.black)
end
--- Run a shell command and print it together with what it returned.
-- The line has the form "[command] -> [result]", printed in blue and
-- followed by colors.black.
-- @param c command line handed to sys.execute()
function logCommand(c)
   local head = colors.blue .. '[' .. c .. '] -> ['
   local result = sys.execute(c)
   print(head .. result .. ']' .. colors.black)
end
50 ----------------------------------------------------------------------
51 -- Environment variables
-- Defaults of the -nbThreads and -useGPU command-line options. They can
-- be overridden through the TORCH_NB_THREADS and TORCH_USE_GPU
-- environment variables.
local defaultNbThreads = 1
local defaultUseGPU = false

do
   local rawNbThreads = os.getenv('TORCH_NB_THREADS')

   if rawNbThreads then
      -- os.getenv() returns a string: convert it so that the default of
      -- -nbThreads is a number, and fall back on the built-in default
      -- when the value is not a positive number
      local nbThreads = tonumber(rawNbThreads)
      if nbThreads and nbThreads >= 1 then
         defaultNbThreads = math.floor(nbThreads)
         print('Environment variable TORCH_NB_THREADS is set and equal to ' .. defaultNbThreads)
      else
         print('Environment variable TORCH_NB_THREADS is set to the invalid value `'
                  .. rawNbThreads .. '\', default is ' .. defaultNbThreads)
      end
   else
      print('Environment variable TORCH_NB_THREADS is not set, default is ' .. defaultNbThreads)
   end

   local rawUseGPU = os.getenv('TORCH_USE_GPU')

   if rawUseGPU then
      -- Only the exact string 'yes' enables the GPU
      defaultUseGPU = rawUseGPU == 'yes'
      print('Environment variable TORCH_USE_GPU is set and evaluated as ' .. tostring(defaultUseGPU))
   else
      print('Environment variable TORCH_USE_GPU is not set, default is ' .. tostring(defaultUseGPU))
   end
end
70 ----------------------------------------------------------------------
71 -- Command line arguments
-- Declare the command-line options, parse them into the global `params`
-- table, then create the run directory and (unless -noLog) its log file.
73 local cmd = torch.CmdLine()
75 cmd:text('General setup')
77 cmd:option('-seed', 1, 'initial random seed')
-- defaultNbThreads and defaultUseGPU are the file-level locals set from
-- the TORCH_NB_THREADS and TORCH_USE_GPU environment variables
78 cmd:option('-nbThreads', defaultNbThreads, 'how many threads (environment variable TORCH_NB_THREADS)')
79 cmd:option('-useGPU', defaultUseGPU, 'should we use cuda (environment variable TORCH_USE_GPU)')
84 cmd:option('-resultFreq', 100, 'at which epoch frequency should we save result images')
85 cmd:option('-exampleInternals', '', 'list of comma-separated indices for inner activation images')
86 cmd:option('-noLog', false, 'should we prevent logging')
87 cmd:option('-rundir', '', 'the directory for results')
88 cmd:option('-deltaImages', false, 'should we highlight the difference in result images')
91 cmd:text('Network structure')
-- These three values are passed to createModel() by createAndTrainModel()
93 cmd:option('-filterSize', 5)
94 cmd:option('-nbChannels', 16)
95 cmd:option('-nbBlocks', 8)
100 cmd:option('-nbEpochs', 1000, 'nb of epochs for the heavy setting')
101 cmd:option('-learningRate', 0.1, 'learning rate')
102 cmd:option('-batchSize', 128, 'size of the mini-batches')
-- The train, validation and test sets are loaded as consecutive ranges
-- of frames, in that order (see the loadData() calls at the end of the file)
103 cmd:option('-nbTrainSamples', 32768)
104 cmd:option('-nbValidationSamples', 1024)
105 cmd:option('-nbTestSamples', 1024)
108 cmd:text('Problem to solve')
110 cmd:option('-dataDir', './data/10p-mg', 'data directory')
112 ------------------------------
115 cmd:addTime('DYNCNN','%F %T')
-- `params` is a global: functions of this file (e.g. highlightImage,
-- fillBatch, trainModel) read it directly
117 params = cmd:parse(arg)
-- Without an explicit -rundir, ask cmd:string() for a directory name
-- built from the 'exp' prefix and the parsed options
119 if params.rundir == '' then
120 params.rundir = cmd:string('exp', params, { })
123 paths.mkdir(params.rundir)
125 if not params.noLog then
126 -- Append to the log if there is one
-- NOTE(review): the result of io.open() is not checked for nil before
-- being handed to cmd:log()
127 cmd:log(io.open(params.rundir .. '/log', 'a'), params)
130 ----------------------------------------------------------------------
131 -- The experiment per se
-- Set params.targetDepth, then configure torch globally.
-- NOTE(review): no -predictGrasp option is declared among the options
-- visible in this file, so params.predictGrasp looks like it is always
-- nil here -- confirm
133 if params.predictGrasp then
134 params.targetDepth = 2
136 params.targetDepth = 1
139 ----------------------------------------------------------------------
-- Number of threads, default tensor type (float) and random seed
142 torch.setnumthreads(params.nbThreads)
143 torch.setdefaulttensortype('torch.FloatTensor')
144 torch.manualSeed(params.seed)
146 ----------------------------------------------------------------------
147 -- Dealing with the CPU/GPU
149 -- mynn will take entries in that order: mynn, cudnn, cunn, nn
-- Lookup fallback for keys missing from mynn: return the first available
-- of cudnn[key], cunn[key] and nn[key]. cudnn and cunn are skipped when
-- those globals are nil. Note that the first parameter is named `table`
-- and therefore shadows the standard table library inside this function.
155 __index = function(table, key)
156 return (cudnn and cudnn[key]) or (cunn and cunn[key]) or nn[key]
161 -- These are the tensors that can be kept on the CPU
162 mynn.SlowTensor = torch.Tensor
164 -- These are the tensors that should be moved to the GPU
165 mynn.FastTensor = torch.Tensor
-- With -useGPU, FastTensor becomes a CUDA tensor; SlowTensor is left
-- unchanged by the lines visible here
167 if params.useGPU then
171 cudnn.benchmark = true
173 mynn.FastTensor = torch.CudaTensor
176 ----------------------------------------------------------------------
-- Load a range of frames from PNG files into CPU tensors.
--   first  1-based index of the first frame of the range
--   nb     NOTE(review): presumably the number of frames to load, stored
--          as data.nbSamples on a line not visible here -- confirm
--   name   label of the set, printed below
-- Fills data.input (nbSamples x 2 x height x width) and data.target
-- (nbSamples x 1 x height x width).
178 function loadData(first, nb, name)
179 print('Loading data `' .. name .. '\'.')
188 data.input = mynn.SlowTensor(data.nbSamples, 2, data.height, data.width)
189 data.target = mynn.SlowTensor(data.nbSamples, 1, data.height, data.width)
191 for i = 1, data.nbSamples do
-- n is the zero-based index of the frame on disk
192 local n = i-1 + first-1
-- Frames are grouped in sub-directories of 1000 files each: the %03d
-- field is floor(n/1000) and the file is dyn_<n>.png.
-- NOTE(review): the argument for the leading %s is not visible here,
-- presumably params.dataDir -- confirm
193 local frame = image.load(string.format('%s/%03d/dyn_%06d.png',
195 math.floor(n/1000), n))
-- Invert the values (x -> 1 - x)
197 frame:mul(-1.0):add(1.0)
-- Take the maximum over the first dimension and drop that dimension
198 frame = frame:max(1):select(1, 1)
-- The loaded frame is read as a 2 x 2 grid of height x width tiles:
-- the top-right tile is input channel 1, ...
200 data.input[i][1]:copy(frame:sub(0 * data.height + 1, 1 * data.height,
201 1 * data.width + 1, 2 * data.width))
-- ... the top-left tile is input channel 2, ...
203 data.input[i][2]:copy(frame:sub(0 * data.height + 1, 1 * data.height,
204 0 * data.width + 1, 1 * data.width))
-- ... and the bottom-right tile is the target
206 data.target[i][1]:copy(frame:sub(1 * data.height + 1, 2 * data.height,
207 1 * data.width + 1, 2 * data.width))
213 ----------------------------------------------------------------------
--- Gather the output tensors of the modules of a model.
-- nn.Sequential containers are walked recursively. Any other module is
-- treated as a leaf: its output is appended to the collection when it is
-- a tensor and the module type is selected.
-- @param model      module to inspect
-- @param collection table with a counter `nb` and a list `outputs`,
--                   both updated in place
-- @param which      optional set of type names (type name -> true);
--                   when nil, every leaf type is selected
function collectAllOutputs(model, collection, which)
   local kind = torch.type(model)

   if kind == 'nn.Sequential' then
      local children = model.modules
      for idx = 1, #children do
         collectAllOutputs(children[idx], collection, which)
      end
      return
   end

   local selected = (not which) or which[kind]

   if selected and torch.isTensor(model.output) then
      local slot = collection.nb + 1
      collection.nb = slot
      collection.outputs[slot] = model.output
   end
end
-- Run the model on sample n of data and save one PNG showing the input
-- together with the inner activations collected by collectAllOutputs().
--   model  network to run forward
--   data   data set table (input, height, width)
--   n      index of the sample in data.input
228 function saveInternalsImage(model, data, n)
229 -- Explicitly copy to keep input as a mynn.FastTensor
230 local input = mynn.FastTensor(1, 2, data.height, data.width)
231 input:copy(data.input:narrow(1, n, 1))
-- The forward pass fills the `output` field of the modules
-- NOTE(review): the local `output` itself is not used in the lines
-- visible here; model.output is read instead
233 local output = model:forward(input)
235 local collection = {}
236 collection.outputs = {}
-- The input comes first in the collection
238 collection.outputs[collection.nb] = input
-- Only the module types listed in the table below are collected
240 collectAllOutputs(model, collection,
243 ['cunn.ReLU'] = true,
244 ['cudnn.ReLU'] = true,
-- Append the final output unless it is already the last collected tensor
248 if collection.outputs[collection.nb] ~= model.output then
249 collection.nb = collection.nb + 1
250 collection.outputs[collection.nb] = model.output
253 local fileName = string.format('%s/internals_%s_%06d.png',
257 print('Saving ' .. fileName)
258 image.save(fileName, imageFromTensors(collection.outputs))
261 ----------------------------------------------------------------------
--- Optionally dim the parts of an image that do not differ from a reference.
-- When params.deltaImages is set, every value of a is multiplied by a
-- weight between 0.1 and 1.0 that grows linearly with |a - b|, so that
-- the regions where a and b differ the most keep their full intensity.
-- @param a image to display (tensor)
-- @param b reference image, same size as a
-- @return a new weighted tensor when params.deltaImages is set, a itself
--         otherwise
function highlightImage(a, b)
   if not params.deltaImages then
      return a
   end

   local h = torch.csub(a, b):abs()

   -- Normalize the absolute difference so that its largest value is 1.
   -- The division has to be by the maximum itself: dividing by
   -- 1/h:max() multiplies by the maximum instead of normalizing.
   -- A zero maximum (a equal to b) is left alone to avoid 0/0.
   local hMax = h:max()
   if hMax > 0 then
      h:div(hMax)
   end

   -- Map [0, 1] onto [0.1, 1.0] so that unchanged regions stay visible
   h:mul(0.9):add(0.1)

   return torch.cmul(a, h)
end
-- Run the model on the first samples of a data set, save one result
-- image per sample in params.rundir, and write the per-sample loss to
-- result_<data.name>_losses.dat.
--   model  network to evaluate
--   data   data set table (input, target, nbSamples, height, width, name)
--   nbMax  optional maximum number of images, 50 when omitted
273 function saveResultImage(model, data, nbMax)
274 local criterion = nn.MSECriterion()
276 if params.useGPU then
277 print('Moving the criterion to the GPU.')
-- Single-sample buffers of the "fast" tensor type
281 local input = mynn.FastTensor(1, 2, data.height, data.width)
282 local target = mynn.FastTensor(1, 1, data.height, data.width)
-- This local shadows the nbMax parameter and applies the default
284 local nbMax = nbMax or 50
286 local nb = math.min(nbMax, data.nbSamples)
290 printf('Write %d result images for `%s\'.', nb, data.name)
-- NOTE(review): the result of io.open() is not checked, and no
-- lossFile:close() appears among the lines visible here -- confirm
292 local lossFile = io.open(params.rundir .. '/result_' .. data.name .. '_losses.dat', 'w')
296 -- Explicitly copy to keep input as a mynn.FastTensor
297 input:copy(data.input:narrow(1, n, 1))
298 target:copy(data.target:narrow(1, n, 1))
300 local output = model:forward(input)
301 local loss = criterion:forward(output, target)
-- Copy the output into a SlowTensor before building the image
303 output = mynn.SlowTensor(output:size()):copy(output)
305 -- We use our magical img.lua to create the result images
-- Tiles: both input channels, then the target and the prediction, the
-- last two going through highlightImage() against input channel 1
312 { pad = 1, data.input[n][1] },
313 { pad = 1, data.input[n][2] },
314 { pad = 1, highlightImage(data.target[n][1], data.input[n][1]) },
315 { pad = 1, highlightImage(output[1][1], data.input[n][1]) },
319 local result = combineImages(1.0, comp)
-- Invert the values (x -> 1 - x) before saving
321 result:mul(-1.0):add(1.0)
323 local fileName = string.format('result_%s_%06d.png', data.name, n)
324 image.save(params.rundir .. '/' .. fileName, result)
-- One "<loss> <file name>" line per sample
325 lossFile:write(string.format('%f %s\n', loss, fileName))
329 ----------------------------------------------------------------------
-- Build the stack of residual blocks of the network.
--   filterSize  side of the square convolution filters
--   nbChannels  number of channels of the batch-normalization layers and
--               first argument of the convolutions
--   nbBlocks    number of residual blocks; 0 gives an identity module
-- NOTE(review): the padding (filterSize - 1) / 2 is an integer only for
-- odd values of filterSize
331 function createTower(filterSize, nbChannels, nbBlocks)
335 if nbBlocks == 0 then
337 tower = nn.Identity()
341 tower = mynn.Sequential()
343 for b = 1, nbBlocks do
-- Inner path of the block: convolution, batch normalization, ReLU,
-- convolution
344 local block = mynn.Sequential()
346 block:add(mynn.SpatialConvolution(nbChannels,
348 filterSize, filterSize,
350 (filterSize - 1) / 2, (filterSize - 1) / 2))
351 block:add(mynn.SpatialBatchNormalization(nbChannels))
352 block:add(mynn.ReLU(true))
354 block:add(mynn.SpatialConvolution(nbChannels,
356 filterSize, filterSize,
358 (filterSize - 1) / 2, (filterSize - 1) / 2))
-- Residual connection: the block and an identity are applied to the
-- same input and their outputs are summed by CAddTable
360 local parallel = mynn.ConcatTable()
361 parallel:add(block):add(mynn.Identity())
363 tower:add(parallel):add(mynn.CAddTable(true))
-- Batch normalization and ReLU applied after the sum
365 tower:add(mynn.SpatialBatchNormalization(nbChannels))
366 tower:add(mynn.ReLU(true))
-- Build the full network: an encoding convolution from the 2 input
-- channels, the residual tower from createTower(), and a decoding
-- convolution.
--   imageWidth, imageHeight  NOTE(review): not used in the lines visible
--                            here -- confirm
--   filterSize, nbChannels, nbBlocks  forwarded to createTower(); the
--                            first two also size the convolutions here
375 function createModel(imageWidth, imageHeight,
376 filterSize, nbChannels, nbBlocks)
378 local model = mynn.Sequential()
380 -- Encode the two input channels (grasping image and starting
381 -- configuration) into the internal number of channels
382 model:add(mynn.SpatialConvolution(2,
384 filterSize, filterSize,
386 (filterSize - 1) / 2, (filterSize - 1) / 2))
388 model:add(mynn.SpatialBatchNormalization(nbChannels))
389 model:add(mynn.ReLU(true))
391 -- Add the resnet modules
392 model:add(createTower(filterSize, nbChannels, nbBlocks))
394 -- Decode down to a single channel, which is the final image
395 model:add(mynn.SpatialConvolution(nbChannels,
397 filterSize, filterSize,
399 (filterSize - 1) / 2, (filterSize - 1) / 2))
404 ----------------------------------------------------------------------
-- Copy a mini-batch of samples from a data set into the batch tensors.
--   data         data set table (input, target)
--   first        position of the first sample of the batch
--   batch        table with the `input` and `target` tensors to fill;
--                they are resized when the batch is shorter than usual
--   permutation  sample order to follow. NOTE(review): the validation
--                loop of trainModel() calls this function without it, so
--                the nil case is presumably handled on lines not visible
--                here -- confirm
406 function fillBatch(data, first, batch, permutation)
-- The last batch of a set can be shorter than params.batchSize
407 local actualBatchSize = math.min(params.batchSize, data.input:size(1) - first + 1)
409 if actualBatchSize ~= batch.input:size(1) then
410 local size = batch.input:size()
411 size[1] = actualBatchSize
412 batch.input:resize(size)
415 if actualBatchSize ~= batch.target:size(1) then
416 local size = batch.target:size()
417 size[1] = actualBatchSize
418 batch.target:resize(size)
421 for k = 1, batch.input:size(1) do
-- NOTE(review): no `local i` is visible here; unless it is declared on
-- a missing line, this assignment creates a global
424 i = permutation[first + k - 1]
428 batch.input[k] = data.input[i]
429 batch.target[k] = data.target[i]
-- Train the model with SGD on an MSE criterion, evaluate it on the
-- validation set after every epoch, and save the model state in
-- params.rundir.
--   model          network to train; its `epoch` and `RNGState` fields
--                  are used to resume a previous run
--   trainSet       data set table used for training
--   validationSet  data set table used for validation
433 function trainModel(model, trainSet, validationSet)
435 local criterion = nn.MSECriterion()
436 local batchSize = params.batchSize
-- Batch buffers, reused and refilled by fillBatch() for every batch
439 batch.input = mynn.FastTensor(batchSize, 2, trainSet.height, trainSet.width)
440 batch.target = mynn.FastTensor(batchSize, 1, trainSet.height, trainSet.width)
442 local startingEpoch = 1
-- Resume after the last epoch recorded in the model
-- NOTE(review): the condition guarding this line is not visible here
445 startingEpoch = model.epoch + 1
448 if model.RNGState then
449 printfc(colors.red, 'Using the RNG state from the loaded model.')
450 torch.setRNGState(model.RNGState)
453 if params.useGPU then
454 print('Moving the model and criterion to the GPU.')
459 print('Starting training.')
-- Flattened views of all the parameters and of their gradients
461 local parameters, gradParameters = model:getParameters()
462 printf('The model has %d parameters.', parameters:storage():size(1))
464 local averageTrainLoss, averageValidationLoss
465 local trainTime, validationTime
467 ----------------------------------------------------------------------
-- Fields of the state table passed to optim.sgd()
470 learningRate = params.learningRate,
472 learningRateDecay = 0
475 for e = startingEpoch, params.nbEpochs do
-- A new random order of the training samples at every epoch
479 local permutation = torch.randperm(trainSet.nbSamples)
483 local startTime = sys.clock()
485 for b = 1, trainSet.nbSamples, batchSize do
487 fillBatch(trainSet, b, batch, permutation)
-- Closure evaluated by optim.sgd(): returns the loss on the current
-- batch and the gradient with respect to the parameters
489 local opfunc = function(x)
490 -- Surprisingly, copy() needs this check
491 if x ~= parameters then
495 local output = model:forward(batch.input)
497 local loss = criterion:forward(output, batch.target)
498 local dLossdOutput = criterion:backward(output, batch.target)
-- The gradients are reset before the backward pass
500 gradParameters:zero()
501 model:backward(batch.input, dLossdOutput)
-- Running sum of the batch losses, averaged after the loop
503 accLoss = accLoss + loss
504 nbBatches = nbBatches + 1
506 return loss, gradParameters
509 optim.sgd(opfunc, parameters, sgdState)
-- The training loss is the mean of the per-batch losses, so it is
-- measured while the parameters are being updated
513 trainTime = sys.clock() - startTime
514 averageTrainLoss = accLoss / nbBatches
516 ----------------------------------------------------------------------
-- Validation pass: forward only, no permutation
524 local startTime = sys.clock()
526 for b = 1, validationSet.nbSamples, batchSize do
527 fillBatch(validationSet, b, batch)
528 local output = model:forward(batch.input)
529 accLoss = accLoss + criterion:forward(output, batch.target)
530 nbBatches = nbBatches + 1
533 validationTime = sys.clock() - startTime
534 averageValidationLoss = accLoss / nbBatches;
537 ----------------------------------------------------------------------
-- One summary line per epoch; the per-sample times are in milliseconds
539 printfc(colors.green,
541 'epoch %d acc_train_loss %f validation_loss %f [train %.02fs total %.02fms / sample, validation %.02fs total %.02fms / sample]',
547 averageValidationLoss,
550 1000 * trainTime / trainSet.nbSamples,
553 1000 * validationTime / validationSet.nbSamples
556 ----------------------------------------------------------------------
557 -- Save a persistent state so that we can restart from there
-- The RNG state is stored in the model so that a restarted run can
-- restore it with torch.setRNGState()
560 model.RNGState = torch.getRNGState()
562 torch.save(params.rundir .. '/model_last.t7', model)
564 ----------------------------------------------------------------------
565 -- Save a duplicate of the persistent state from time to time
-- Every params.resultFreq epochs: numbered copy of the model and
-- result images for both sets; a frequency of 0 or less disables this
567 if params.resultFreq > 0 and e%params.resultFreq == 0 then
568 torch.save(string.format('%s/model_%04d.t7', params.rundir, e), model)
569 saveResultImage(model, trainSet)
570 saveResultImage(model, validationSet)
-- Resume from params.rundir/model_last.t7 when it can be loaded,
-- otherwise create a new model, then train it.
--   trainSet, validationSet  data set tables, passed on to trainModel()
577 function createAndTrainModel(trainSet, validationSet)
579 -- Load the current training state, or create a new model from
-- pcall() turns any error of torch.load() into a false result, so a
-- missing file and an unreadable one are handled the same way.
-- NOTE(review): `model` is assigned inside the closure; unless it is
-- declared local on a line not visible here, it is a global
582 if pcall(function () model = torch.load(params.rundir .. '/model_last.t7') end) then
585 'Found a model with %d epochs completed, starting from there.',
-- -exampleInternals is a comma-separated list of validation sample
-- indices for which an internals image is saved
588 if params.exampleInternals ~= '' then
589 for _, i in ipairs(string.split(params.exampleInternals, ',')) do
590 saveInternalsImage(model, validationSet, tonumber(i))
597 model = createModel(trainSet.width, trainSet.height,
598 params.filterSize, params.nbChannels,
603 trainModel(model, trainSet, validationSet)
609 ----------------------------------------------------------------------
-- Main script: load the data sets, train the model, then save result
-- images for the three sets.
-- NOTE(review): the call this git command belongs to is not visible
-- here, presumably logCommand() -- confirm
615 'git log -1 --format=%H'
-- The three sets are consecutive ranges of frames: train first, then
-- validation, then test
621 local trainSet = loadData(1,
622 params.nbTrainSamples, 'train')
624 local validationSet = loadData(params.nbTrainSamples + 1,
625 params.nbValidationSamples, 'validation')
627 local model = createAndTrainModel(trainSet, validationSet)
629 ----------------------------------------------------------------------
632 local testSet = loadData(params.nbTrainSamples + params.nbValidationSamples + 1,
633 params.nbTestSamples, 'test')
635 if params.useGPU then
636 print('Moving the model and criterion to the GPU.')
-- Train and validation use the default image limit of
-- saveResultImage(); the test set is allowed up to 1024 images
640 saveResultImage(model, trainSet)
641 saveResultImage(model, validationSet)
642 saveResultImage(model, testSet, 1024)