5 dyncnn is a deep-learning algorithm for the prediction of
6 interacting object dynamics
8 Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
9 Written by Francois Fleuret <francois.fleuret@idiap.ch>
11 This file is part of dyncnn.
13 dyncnn is free software: you can redistribute it and/or modify it
14 under the terms of the GNU General Public License version 3 as
15 published by the Free Software Foundation.
17 dyncnn is distributed in the hope that it will be useful, but
18 WITHOUT ANY WARRANTY; without even the implied warranty of
19 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20 General Public License for more details.
22 You should have received a copy of the GNU General Public License
23 along with dyncnn. If not, see <http://www.gnu.org/licenses/>.
33 ----------------------------------------------------------------------
36 --seed (default 1) random seed
38 --learningStateFile (default '')
39 --dataDir (default './data/10p-mg/')
40 --resultDir (default '/tmp/dyncnn')
42 --learningRate (default -1)
43 --momentum (default -1)
44 --nbEpochs (default -1) nb of epochs for the heavy setting
46 --heavy use the heavy configuration
47 --nbChannels (default -1) nb of channels in the internal layers
48 --resultFreq (default 100)
50 --noLog supress logging
52 --exampleInternals (default -1)
55 ----------------------------------------------------------------------
-- NOTE(review): this SOURCE is a sparse extract — each line is prefixed with
-- its original file line number and many interior lines are missing, so only
-- comments are added throughout; code bytes are left untouched.
-- Rebuild the full invocation (each argument single-quoted) for logging.
59 commandLine = commandLine .. ' \'' .. arg[i] .. '\''
62 ----------------------------------------------------------------------
-- Append s to the global log file (when one is open) and flush immediately
-- so the log survives a crash. c is an optional terminal color, defaulting
-- to colors.black. NOTE(review): the console-echo part and closing `end`s
-- (original lines 72, 74-77) are missing from this extract.
68 function logString(s, c)
69 if global.logFile then
70 global.logFile:write(s)
71 global.logFile:flush()
73 local c = c or colors.black
-- Run shell command c via sys.execute and log "[cmd] -> [output]" in blue.
-- NOTE(review): c is interpolated into a shell unescaped — fine for the
-- fixed strings used here, but never feed it untrusted input.
78 function logCommand(c)
79 logString('[' .. c .. '] -> [' .. sys.execute(c) .. ']\n', colors.blue)
-- Record the command line, make sure the result directory exists, then open
-- the log file in append mode so successive runs accumulate in one file.
82 logString('commandline: ' .. commandLine .. '\n', colors.blue)
84 logCommand('mkdir -v -p ' .. opt.resultDir)
87 global.logName = opt.resultDir .. '/log'
88 global.logFile = io.open(global.logName, 'a')
91 ----------------------------------------------------------------------
-- Log-once machinery: remembers, per source line, that a message was already
-- emitted, so repeated calls from the same call site log only the first time.
93 alreadyLoggedString = {}
-- NOTE(review): the enclosing function header and `end`s are missing from
-- this extract; `s` presumably is that function's message parameter.
96 local l = debug.getinfo(1).currentline
97 if not alreadyLoggedString[l] then
98 logString('@line ' .. l .. ' ' .. s, colors.red)
99 alreadyLoggedString[l] = s
103 ----------------------------------------------------------------------
-- Runtime knobs come from environment variables: thread count and GPU use.
105 nbThreads = os.getenv('TORCH_NB_THREADS') or 1
107 useGPU = os.getenv('TORCH_USE_GPU') == 'yes'
-- Log provenance (current date, git revision) so every log is reproducible.
-- NOTE(review): the loop body calling logCommand is missing from this extract.
109 for _, c in pairs({ 'date',
111 'git log -1 --format=%H'
117 logString('useGPU is \'' .. tostring(useGPU) .. '\'.\n')
119 logString('nbThreads is \'' .. nbThreads .. '\'.\n')
121 ----------------------------------------------------------------------
-- Torch global setup: thread pool, float as the default tensor type, and a
-- fixed RNG seed (from --seed) for reproducible runs.
123 torch.setnumthreads(nbThreads)
124 torch.setdefaulttensortype('torch.FloatTensor')
125 torch.manualSeed(opt.seed)
129 -- To deal elegantly with CPU/GPU
-- mynn is a proxy namespace: looking up mynn.X resolves to cudnn.X if cudnn
-- is loaded, else cunn.X, else nn.X — so model-building code is backend-free.
-- NOTE(review): the declaration of `mt` (and `mynn`) is missing from this
-- extract; it is presumably a plain table created just above.
131 function mt.__index(table, key)
132 return (cudnn and cudnn[key]) or (cunn and cunn[key]) or nn[key]
134 setmetatable(mynn, mt)
136 -- These are the tensors that can be kept on the CPU
137 mynn.SlowTensor = torch.Tensor
138 -- These are the tensors that should be moved to the GPU
139 mynn.FastTensor = torch.Tensor
141 ----------------------------------------------------------------------
-- When the GPU is enabled, FastTensor becomes a CudaTensor and cudnn is put
-- in benchmark mode (auto-tunes convolution algorithms).
148 mynn.FastTensor = torch.CudaTensor
151 cudnn.benchmark = true
156 ----------------------------------------------------------------------
-- Base hyper-parameters shared by both configurations.
159 config.learningRate = 0.1
161 config.batchSize = 128
162 config.filterSize = 5
-- Heavy configuration (--heavy): wider network, long training, big train set.
166 logString('Using the heavy configuration.\n')
167 config.nbChannels = 16
169 config.nbEpochs = 250
170 config.nbEpochsInit = 100
171 config.nbTrainSamples = 32768
172 config.nbValidationSamples = 1024
173 config.nbTestSamples = 1024
-- Light configuration (default): tiny network and few epochs, for quick tests.
177 logString('Using the light configuration.\n')
178 config.nbChannels = 2
181 config.nbEpochsInit = 3
182 config.nbTrainSamples = 1024
183 config.nbValidationSamples = 1024
184 config.nbTestSamples = 1024
-- Command-line overrides: the options default to -1, which means "keep the
-- configuration value"; only explicit positive values take effect (momentum
-- legitimately allows 0, hence the >= 0 test).
188 if opt.nbEpochs > 0 then
189 config.nbEpochs = opt.nbEpochs
192 if opt.nbChannels > 0 then
193 config.nbChannels = opt.nbChannels
196 if opt.learningRate > 0 then
197 config.learningRate = opt.learningRate
200 if opt.momentum >= 0 then
201 config.momentum = opt.momentum
204 ----------------------------------------------------------------------
-- Walk every module of `model` and tally, per field name (weight, bias,
-- gradWeight, ...), the total element count of tensors whose torch.type
-- matches `tensorType`. NOTE(review): the declaration of the accumulator
-- `nb` and the return are missing from this extract.
206 function tensorCensus(tensorType, model)
210 local function countThings(m)
211 for k, i in pairs(m) do
212 if torch.type(i) == tensorType then
213 nb[k] = (nb[k] or 0) + i:nElement()
218 model:apply(countThings)
224 ----------------------------------------------------------------------
-- Load `nb` samples starting at index `first` (1-based) as a dataset named
-- `name`. The preprocessed tensors are cached in a persistent .dat file in
-- dataDir: built from the PNG frames on first use, torch.load'ed afterwards.
226 function loadData(first, nb, name)
227 logString('Loading data `' .. name .. '\'.\n')
229 local persistentFileName = string.format('%s/persistent_%d_%d.dat',
234 -- This is at what framerate we work. It is greater than 1 so that
235 -- we can keep on disk sequences at a higher frame rate for videos
236 -- and explaining materials
242 if not path.exists(persistentFileName) then
243 logString(string.format('No persistent data structure, creating it (%d samples).\n', nb))
-- input has 2 channels (world image + grab map), target 1 channel.
-- NOTE(review): the creation of `data` and its height/width fields is in
-- lines missing from this extract.
249 data.input = mynn.SlowTensor(data.nbSamples, 2, data.height, data.width)
250 data.target = mynn.SlowTensor(data.nbSamples, 1, data.height, data.width)
252 for i = 1, data.nbSamples do
-- Samples are numbered from 0 and sharded 1000 per directory on disk.
253 local n = i-1 + first-1
254 local prefix = string.format('%s/%03d/dyn_%06d',
256 math.floor(n/1000), n)
-- Load a PNG, invert it (1 - pixel, so ink is 1 on a 0 background) and
-- collapse color channels with a per-pixel max into `tensor`.
-- NOTE(review): `localLoad` and `tmp` are accidental globals (no `local`) —
-- harmless here but worth fixing in the full file.
258 function localLoad(filename, tensor)
260 tmp = image.load(filename)
261 tmp:mul(-1.0):add(1.0)
262 tensor:copy(torch.max(tmp, 1))
265 localLoad(prefix .. '_world_000.png', data.input[i][1])
266 localLoad(prefix .. '_grab.png', data.input[i][2])
-- Target is the world state `frameRate` frames later (frameRate defined in
-- the missing lines 234-241 region).
267 localLoad(string.format('%s_world_%03d.png', prefix, frameRate),
271 data.persistentFileName = persistentFileName
273 torch.save(persistentFileName, data)
-- Log a checksum of the cache so runs can be matched to exact data.
276 logCommand('sha256sum -b ' .. persistentFileName)
278 data = torch.load(persistentFileName)
283 ----------------------------------------------------------------------
285 -- This function gets as input a list of tensors of arbitrary
286 -- dimensions each, but whose two last dimension stands for height x
287 -- width. It creates an image tensor (2d, one channel) with each
288 -- argument tensor unfolded per row.
-- If `signed` is true, values are rendered on a blue(negative)/red(positive)
-- scale; otherwise on a grayscale. Values are normalized by the tensor's
-- RMS (z below). NOTE(review): several layout lines (gap/tgap/y0 setup,
-- loop headers, `end`s and the return) are missing from this extract.
290 function imageFromTensors(bt, signed)
-- First pass: compute the canvas size needed to fit every tensor's slices
-- side by side, one tensor per row.
296 for _, t in pairs(bt) do
299 local h, w = t:size(d - 1), t:size(d)
300 local n = t:nElement() / (w * h)
301 width = math.max(width, gap + n * (gap + w))
302 height = height + gap + tgap + gap + h
-- White RGB canvas.
305 local e = torch.Tensor(3, height, width):fill(1.0)
-- Second pass: draw each tensor, centered horizontally.
308 for _, t in pairs(bt) do
310 local h, w = t:size(d - 1), t:size(d)
311 local n = t:nElement() / (w * h)
-- z = RMS of the tensor, used to normalize pixel intensities below.
312 local z = t:norm() / math.sqrt(t:nElement())
314 local x0 = 1 + gap + math.floor( (width - n * (w + gap)) /2 )
-- View the tensor as n slices of h x w.
315 local u = torch.Tensor(t:size()):copy(t):resize(n, h, w)
-- Black one-pixel frame around each slice.
320 e[c][y0 + y - 1][x0 - 1] = 0.0
321 e[c][y0 + y - 1][x0 + w ] = 0.0
324 e[c][y0 - 1][x0 + x - 1] = 0.0
325 e[c][y0 + h ][x0 + x - 1] = 0.0
331 local v = u[m][y][x] / z
-- Signed palette: saturate to pure blue/red, shade toward white near 0.
335 r, g, b = 0.0, 0.0, 1.0
337 r, g, b = 1.0, 0.0, 0.0
339 r, g, b = 1.0, 1.0 - v, 1.0 - v
341 r, g, b = 1.0 + v, 1.0 + v, 1.0
-- Unsigned palette: clamp then grayscale (1 = white, 0 = black).
345 r, g, b = 1.0, 1.0, 1.0
347 r, g, b = 0.0, 0.0, 0.0
349 r, g, b = 1.0 - v, 1.0 - v, 1.0 - v
352 e[1][y0 + y - 1][x0 + x - 1] = r
353 e[2][y0 + y - 1][x0 + x - 1] = g
354 e[3][y0 + y - 1][x0 + x - 1] = b
-- Advance to the next row of the canvas.
359 y0 = y0 + h + gap + tgap + gap
-- Recursively gather the `output` tensors of a model's modules into
-- collection.outputs (count kept in collection.nb). Sequential containers
-- are descended into; leaves are kept only when `which` is nil or lists
-- their module type, and only Float/Cuda tensor outputs are recorded.
365 function collectAllOutputs(model, collection, which)
366 if torch.type(model) == 'nn.Sequential' then
367 for i = 1, #model.modules do
368 collectAllOutputs(model.modules[i], collection, which)
370 elseif not which or which[torch.type(model)] then
371 local t = torch.type(model.output)
372 if t == 'torch.FloatTensor' or t == 'torch.CudaTensor' then
373 collection.nb = collection.nb + 1
374 collection.outputs[collection.nb] = model.output
-- Run sample n of `data` through the model and save a PNG visualizing the
-- input, every ReLU activation, and the final output (via imageFromTensors).
379 function saveInternalsImage(model, data, n)
380 -- Explicitely copy to keep input as a mynn.FastTensor
381 local input = mynn.FastTensor(1, 2, data.height, data.width)
382 input:copy(data.input:narrow(1, n, 1))
384 local output = model:forward(input)
386 local collection = {}
387 collection.outputs = {}
-- NOTE(review): the initialization of collection.nb and of `which` is in
-- lines missing from this extract.
389 collection.outputs[collection.nb] = input
-- Only ReLU layer outputs are collected from the interior of the model.
392 which['nn.ReLU'] = true
393 collectAllOutputs(model, collection, which)
-- Make sure the final model output is always included, even if the last
-- module is not a ReLU.
395 if collection.outputs[collection.nb] ~= model.output then
396 collection.nb = collection.nb + 1
397 collection.outputs[collection.nb] = model.output
400 local fileName = string.format('%s/internals_%s_%06d.png',
404 logString('Saving ' .. fileName .. '\n')
405 image.save(fileName, imageFromTensors(collection.outputs))
408 ----------------------------------------------------------------------
-- Save up to nbMax result images for `data`: each image stacks 4 panels
-- (grab map, input world, target, prediction), and logs the per-sample MSE.
-- `highlight` switches to a variant that emphasizes pixels that changed
-- between input and target/prediction. prefix defaults to 'result',
-- nbMax to 50.
410 function saveResultImage(model, data, prefix, nbMax, highlight)
411 local l2criterion = nn.MSECriterion()
-- NOTE(review): the GPU guard around this move (and several loop headers /
-- `end`s below) are in lines missing from this extract.
414 logString('Moving the criterion to the GPU.\n')
418 local prefix = prefix or 'result'
-- 4 panels of data.height each, plus 1-pixel separators and border.
419 local result = torch.Tensor(data.height * 4 + 5, data.width + 2)
420 local input = mynn.FastTensor(1, 2, data.height, data.width)
421 local target = mynn.FastTensor(1, 1, data.height, data.width)
423 local nbMax = nbMax or 50
425 local nb = math.min(nbMax, data.nbSamples)
429 logString(string.format('Write %d result images `%s\' for set `%s\' in %s.\n',
430 nb, prefix, data.name,
435 -- Explicitely copy to keep input as a mynn.FastTensor
436 input:copy(data.input:narrow(1, n, 1))
437 target:copy(data.target:narrow(1, n, 1))
439 local output = model:forward(input)
441 local loss = l2criterion:forward(output, target)
-- Highlight variant: scale target/prediction by how much each pixel
-- differs from the input world, so moving parts stand out.
446 for i = 1, data.height do
447 for j = 1, data.width do
448 local v = data.input[n][1][i][j]
449 result[1 + i + 0 * (data.height + 1)][1 + j] = data.input[n][2][i][j]
450 result[1 + i + 1 * (data.height + 1)][1 + j] = v
451 local a = data.target[n][1][i][j]
452 local b = output[1][1][i][j]
453 result[1 + i + 2 * (data.height + 1)][1 + j] =
454 a * math.min(1, 0.1 + 2.0 * math.abs(a - v))
455 result[1 + i + 3 * (data.height + 1)][1 + j] =
456 b * math.min(1, 0.1 + 2.0 * math.abs(b - v))
-- Plain variant: the four panels verbatim.
460 for i = 1, data.height do
461 for j = 1, data.width do
462 result[1 + i + 0 * (data.height + 1)][1 + j] = data.input[n][2][i][j]
463 result[1 + i + 1 * (data.height + 1)][1 + j] = data.input[n][1][i][j]
464 result[1 + i + 2 * (data.height + 1)][1 + j] = data.target[n][1][i][j]
465 result[1 + i + 3 * (data.height + 1)][1 + j] = output[1][1][i][j]
-- Invert for display: ink dark on white, matching the source PNGs.
470 result:mul(-1.0):add(1.0)
472 local fileName = string.format('%s/%s_%s_%06d.png',
477 logString(string.format('LOSS_ON_SAMPLE %f %s\n', loss, fileName))
479 image.save(fileName, result)
483 ----------------------------------------------------------------------
-- Build a residual tower of nbBlocks blocks. Each block is
-- conv -> batchnorm -> ReLU -> conv, summed with an identity shortcut
-- (ConcatTable + CAddTable), followed by batchnorm and an in-place ReLU.
-- Padding (filterSize-1)/2 keeps spatial dimensions unchanged throughout.
-- NOTE(review): stride arguments and the function's return are in lines
-- missing from this extract.
485 function createTower(filterSize, nbChannels, nbBlocks)
486 local tower = mynn.Sequential()
488 for b = 1, nbBlocks do
489 local block = mynn.Sequential()
491 block:add(mynn.SpatialConvolution(nbChannels,
493 filterSize, filterSize,
495 (filterSize - 1) / 2, (filterSize - 1) / 2))
496 block:add(mynn.SpatialBatchNormalization(nbChannels))
497 block:add(mynn.ReLU(true))
499 block:add(mynn.SpatialConvolution(nbChannels,
501 filterSize, filterSize,
503 (filterSize - 1) / 2, (filterSize - 1) / 2))
505 local parallel = mynn.ConcatTable()
506 parallel:add(block):add(mynn.Identity())
-- Residual sum (in-place add of the identity branch).
508 tower:add(parallel):add(mynn.CAddTable(true))
510 tower:add(mynn.SpatialBatchNormalization(nbChannels))
511 tower:add(mynn.ReLU(true))
-- Assemble the full network: a 2->nbChannels stem convolution with
-- batchnorm+ReLU, two residual towers (encode and decode), then a final
-- convolution down to a single output channel (the predicted image).
-- Same-padding throughout keeps the output at the input resolution.
-- NOTE(review): the towerCode insertion and the return statement are in
-- lines missing from this extract.
517 function createModel(filterSize, nbChannels, nbBlocks)
518 local model = mynn.Sequential()
-- Stem: the 2 input channels are the world image and the grab map.
520 model:add(mynn.SpatialConvolution(2,
522 filterSize, filterSize,
524 (filterSize - 1) / 2, (filterSize - 1) / 2))
526 model:add(mynn.SpatialBatchNormalization(nbChannels))
527 model:add(mynn.ReLU(true))
529 local towerCode = createTower(filterSize, nbChannels, nbBlocks)
530 local towerDecode = createTower(filterSize, nbChannels, nbBlocks)
533 model:add(towerDecode)
535 -- Decode to a single channel, which is the final image
536 model:add(mynn.SpatialConvolution(nbChannels,
538 filterSize, filterSize,
540 (filterSize - 1) / 2, (filterSize - 1) / 2))
545 ----------------------------------------------------------------------
-- Copy nb samples of `data`, starting at index `first`, into batch.input /
-- batch.target. When `permutation` is given, sample indices are drawn
-- through it (shuffled training order). NOTE(review): the branch handling a
-- nil permutation (validation calls pass none) and the loop header are in
-- lines missing from this extract; `i` appears to be assigned without
-- `local` here — verify in the full file.
547 function fillBatch(data, first, nb, batch, permutation)
551 i = permutation[first + k - 1]
555 batch.input[k] = data.input[i]
556 batch.target[k] = data.target[i]
-- Train `model` with SGD on an L2 (MSE) loss, logging per-epoch train and
-- validation losses and timings, and checkpointing the learning state so an
-- interrupted run can resume at the right epoch with the right RNG state.
560 function trainModel(model,
561 trainData, validationData, nbEpochs, learningRate,
564 local l2criterion = nn.MSECriterion()
565 local batchSize = config.batchSize
-- NOTE(review): the GPU guard around this move, the `batch` declaration,
-- and several loop/`end` lines below are missing from this extract.
568 logString('Moving the criterion to the GPU.\n')
-- Pre-allocated batch buffers, reused across all iterations.
573 batch.input = mynn.FastTensor(batchSize, 2, trainData.height, trainData.width)
574 batch.target = mynn.FastTensor(batchSize, 1, trainData.height, trainData.width)
-- Resume support: a checkpointed model carries its finished epoch count and
-- the RNG state at checkpoint time.
576 local startingEpoch = 1
579 startingEpoch = model.epoch + 1
582 if model.RNGState then
583 torch.setRNGState(model.RNGState)
586 logString('Starting training.\n')
588 local parameters, gradParameters = model:getParameters()
589 logString(string.format('model has %d parameters.\n', parameters:storage():size(1)))
591 local averageTrainLoss, averageValidationLoss
592 local trainTime, validationTime
-- SGD hyper-parameters (sgdState, declared in a missing line).
595 learningRate = config.learningRate,
596 momentum = config.momentum,
597 learningRateDecay = 0
600 for e = startingEpoch, nbEpochs do
-- Fresh shuffle of the training set each epoch.
604 local permutation = torch.randperm(trainData.nbSamples)
608 local startTime = sys.clock()
610 for b = 1, trainData.nbSamples, batchSize do
612 fillBatch(trainData, b, batchSize, batch, permutation)
-- Closure evaluated by optim.sgd: forward, loss, backward; returns
-- loss and the gradient of the loss w.r.t. the flat parameter vector.
614 local opfunc = function(x)
615 -- Surprisingly copy() needs this check
616 if x ~= parameters then
620 local output = model:forward(batch.input)
621 local loss = l2criterion:forward(output, batch.target)
623 local dLossdOutput = l2criterion:backward(output, batch.target)
624 gradParameters:zero()
625 model:backward(batch.input, dLossdOutput)
627 accLoss = accLoss + loss
628 nbBatches = nbBatches + 1
630 return loss, gradParameters
-- NOTE(review): optim.sgd's documented signature is (opfunc, x, state);
-- argument order here matches that, with parameters as x.
633 optim.sgd(opfunc, parameters, sgdState)
637 trainTime = sys.clock() - startTime
638 averageTrainLoss = accLoss / nbBatches
640 ----------------------------------------------------------------------
-- Validation pass: forward only, no permutation, accumulating MSE.
647 local startTime = sys.clock()
649 for b = 1, validationData.nbSamples, batchSize do
650 fillBatch(validationData, b, batchSize, batch)
651 local output = model:forward(batch.input)
652 accLoss = accLoss + l2criterion:forward(output, batch.target)
653 nbBatches = nbBatches + 1
656 validationTime = sys.clock() - startTime
657 averageValidationLoss = accLoss / nbBatches;
660 logString(string.format('Epoch train %0.2fs (%0.2fms / sample), validation %0.2fs (%0.2fms / sample).\n',
662 1000 * trainTime / trainData.nbSamples,
664 1000 * validationTime / validationData.nbSamples))
-- Machine-greppable loss line: epoch, train loss, validation loss.
666 logString(string.format('LOSS %d %f %f\n', e, averageTrainLoss, averageValidationLoss),
669 ----------------------------------------------------------------------
670 -- Save a persistent state so that we can restart from there
672 if learningStateFile then
673 model.RNGState = torch.getRNGState()
676 logString('Writing ' .. learningStateFile .. '.\n')
677 torch.save(learningStateFile, model)
680 ----------------------------------------------------------------------
681 -- Save a duplicate of the persistent state from time to time
683 if opt.resultFreq > 0 and e%opt.resultFreq == 0 then
684 torch.save(string.format('%s/epoch_%05d_model', opt.resultDir, e), model)
685 saveResultImage(model, trainData)
686 saveResultImage(model, validationData)
-- Entry point for model construction: either resume from a saved learning
-- state (via pcall'ed torch.load, so a missing file falls through to a fresh
-- model) or build a new network, move it to the GPU when enabled, optionally
-- dump an internals image, then train.
693 function createAndTrainModel(trainData, validationData)
697 local learningStateFile = opt.learningStateFile
-- Default checkpoint location inside the result directory.
699 if learningStateFile == '' then
700 learningStateFile = opt.resultDir .. '/learning.state'
703 local gotlearningStateFile
705 logString('Using the learning state file ' .. learningStateFile .. '\n')
-- pcall swallows the load error: absence of a checkpoint is expected on a
-- first run, not a failure.
707 if pcall(function () model = torch.load(learningStateFile) end) then
709 gotlearningStateFile = true
713 model = createModel(config.filterSize, config.nbChannels, config.nbBlocks)
-- NOTE(review): the GPU guard around this move is in missing lines.
716 logString('Moving the model to the GPU.\n')
722 logString(tostring(model) .. '\n')
724 if gotlearningStateFile then
725 logString(string.format('Found a learning state with %d epochs finished.\n', model.epoch),
-- --exampleInternals n dumps the internal activations for validation
-- sample n before (re)training.
729 if opt.exampleInternals > 0 then
730 saveInternalsImage(model, validationData, opt.exampleInternals)
-- NOTE(review): the trainModel call's first line is missing; these are its
-- trailing arguments.
735 trainData, validationData,
736 config.nbEpochs, config.learningRate,
-- Top level: dump the effective configuration, load the three datasets as
-- consecutive slices of the sample sequence, train, then write result
-- images for all three sets (the test set exhaustively).
743 for i, j in pairs(config) do
744 logString('config ' .. i .. ' = \'' .. j ..'\'\n')
747 local trainData = loadData(1, config.nbTrainSamples, 'train')
748 local validationData = loadData(config.nbTrainSamples + 1, config.nbValidationSamples, 'validation')
749 local testData = loadData(config.nbTrainSamples + config.nbValidationSamples + 1, config.nbTestSamples, 'test')
751 local model = createAndTrainModel(trainData, validationData)
753 saveResultImage(model, trainData)
754 saveResultImage(model, validationData)
755 saveResultImage(model, testData, nil, testData.nbSamples)