--- /dev/null
+#!/usr/bin/env luajit
+
+--[[
+
+ dyncnn is a deep-learning algorithm for the prediction of
+ interacting object dynamics
+
+ Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
+ Written by Francois Fleuret <francois.fleuret@idiap.ch>
+
+ This file is part of dyncnn.
+
+ dyncnn is free software: you can redistribute it and/or modify it
+ under the terms of the GNU General Public License version 3 as
+ published by the Free Software Foundation.
+
+ dyncnn is distributed in the hope that it will be useful, but
+ WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with dyncnn. If not, see <http://www.gnu.org/licenses/>.
+
+]]--
+
+require 'torch'
+require 'nn'
+require 'optim'
+require 'image'
+require 'pl'
+
+----------------------------------------------------------------------
+
+local opt = lapp[[
+ --seed (default 1) random seed
+
+ --learningStateFile (default '')
+ --dataDir (default './data/10p-mg/')
+ --resultDir (default '/tmp/dyncnn')
+
+ --learningRate (default -1)
+ --momentum (default -1)
+ --nbEpochs (default -1) nb of epochs for the heavy setting
+
+ --heavy use the heavy configuration
+ --nbChannels (default -1) nb of channels in the internal layers
+ --resultFreq (default 100)
+
+ --noLog supress logging
+
+ --exampleInternals (default -1)
+]]
+
+----------------------------------------------------------------------
+
+commandLine=''
+for i = 0, #arg do
+ commandLine = commandLine .. ' \'' .. arg[i] .. '\''
+end
+
+----------------------------------------------------------------------
+
+colors = sys.COLORS
+
+global = {}
+
+function logString(s, c)
+ if global.logFile then
+ global.logFile:write(s)
+ global.logFile:flush()
+ end
+ local c = c or colors.black
+ io.write(c .. s)
+ io.flush()
+end
+
+function logCommand(c)
+ logString('[' .. c .. '] -> [' .. sys.execute(c) .. ']\n', colors.blue)
+end
+
+logString('commandline: ' .. commandLine .. '\n', colors.blue)
+
+logCommand('mkdir -v -p ' .. opt.resultDir)
+
+if not opt.noLog then
+ global.logName = opt.resultDir .. '/log'
+ global.logFile = io.open(global.logName, 'a')
+end
+
+----------------------------------------------------------------------
+
+alreadyLoggedString = {}
+
+function logOnce(s)
+ local l = debug.getinfo(1).currentline
+ if not alreadyLoggedString[l] then
+ logString('@line ' .. l .. ' ' .. s, colors.red)
+ alreadyLoggedString[l] = s
+ end
+end
+
+----------------------------------------------------------------------
+
+nbThreads = os.getenv('TORCH_NB_THREADS') or 1
+
+useGPU = os.getenv('TORCH_USE_GPU') == 'yes'
+
+for _, c in pairs({ 'date',
+ 'uname -a',
+ 'git log -1 --format=%H'
+ })
+do
+ logCommand(c)
+end
+
+logString('useGPU is \'' .. tostring(useGPU) .. '\'.\n')
+
+logString('nbThreads is \'' .. nbThreads .. '\'.\n')
+
+----------------------------------------------------------------------
+
+torch.setnumthreads(nbThreads)
+torch.setdefaulttensortype('torch.FloatTensor')
+torch.manualSeed(opt.seed)
+
+mynn = {}
+
+-- By default, mynn returns the entries from nn
+local mt = {}
+function mt.__index(table, key)
+ return nn[key]
+end
+setmetatable(mynn, mt)
+
+-- These are the tensors that can be kept on the CPU
+mynn.SlowTensor = torch.Tensor
+-- These are the tensors that should be moved to the GPU
+mynn.FastTensor = torch.Tensor
+
+----------------------------------------------------------------------
+
+if useGPU then
+ require 'cutorch'
+ require 'cunn'
+ require 'cudnn'
+ mynn.FastTensor = torch.CudaTensor
+ mynn.SpatialConvolution = cudnn.SpatialConvolution
+end
+
+----------------------------------------------------------------------
+
+config = {}
+config.learningRate = 0.1
+config.momentum = 0
+config.batchSize = 128
+config.filterSize = 5
+
+if opt.heavy then
+
+ logString('Using the heavy configuration.\n')
+ config.nbChannels = 16
+ config.nbBlocks = 4
+ config.nbEpochs = 250
+ config.nbEpochsInit = 100
+ config.nbTrainSamples = 32768
+ config.nbValidationSamples = 1024
+ config.nbTestSamples = 1024
+
+else
+
+ logString('Using the light configuration.\n')
+ config.nbChannels = 2
+ config.nbBlocks = 2
+ config.nbEpochs = 6
+ config.nbEpochsInit = 3
+ config.nbTrainSamples = 1024
+ config.nbValidationSamples = 1024
+ config.nbTestSamples = 1024
+
+end
+
+if opt.nbEpochs > 0 then
+ config.nbEpochs = opt.nbEpochs
+end
+
+if opt.nbChannels > 0 then
+ config.nbChannels = opt.nbChannels
+end
+
+if opt.learningRate > 0 then
+ config.learningRate = opt.learningRate
+end
+
+if opt.momentum >= 0 then
+ config.momentum = opt.momentum
+end
+
+----------------------------------------------------------------------
+
+function tensorCensus(tensorType, model)
+
+ local nb = {}
+
+ local function countThings(m)
+ for k, i in pairs(m) do
+ if torch.type(i) == tensorType then
+ nb[k] = (nb[k] or 0) + i:nElement()
+ end
+ end
+ end
+
+ model:apply(countThings)
+
+ return nb
+
+end
+
+----------------------------------------------------------------------
+
+function loadData(first, nb, name)
+ logString('Loading data `' .. name .. '\'.\n')
+
+ local persistentFileName = string.format('%s/persistent_%d_%d.dat',
+ opt.dataDir,
+ first,
+ nb)
+
+ -- This is at what framerate we work. It is greater than 1 so that
+ -- we can keep on disk sequences at a higher frame rate for videos
+ -- and explaining materials
+
+ local frameRate = 4
+
+ local data
+
+ if not path.exists(persistentFileName) then
+ logString(string.format('No persistent data structure, creating it (%d samples).\n', nb))
+ local data = {}
+ data.name = name
+ data.nbSamples = nb
+ data.width = 64
+ data.height = 64
+ data.input = mynn.SlowTensor(data.nbSamples, 2, data.height, data.width)
+ data.target = mynn.SlowTensor(data.nbSamples, 1, data.height, data.width)
+
+ for i = 1, data.nbSamples do
+ local n = i-1 + first-1
+ local prefix = string.format('%s/%03d/dyn_%06d',
+ opt.dataDir,
+ math.floor(n/1000), n)
+
+ function localLoad(filename, tensor)
+ local tmp
+ tmp = image.load(filename)
+ tmp:mul(-1.0):add(1.0)
+ tensor:copy(torch.max(tmp, 1))
+ end
+
+ localLoad(prefix .. '_world_000.png', data.input[i][1])
+ localLoad(prefix .. '_grab.png', data.input[i][2])
+ localLoad(string.format('%s_world_%03d.png', prefix, frameRate),
+ data.target[i][1])
+ end
+
+ data.persistentFileName = persistentFileName
+
+ torch.save(persistentFileName, data)
+ end
+
+ logCommand('sha256sum -b ' .. persistentFileName)
+
+ data = torch.load(persistentFileName)
+
+ return data
+end
+
+----------------------------------------------------------------------
+
+-- This function gets as input a list of tensors of arbitrary
+-- dimensions each, but whose two last dimension stands for height x
+-- width. It creates an image tensor (2d, one channel) with each
+-- argument tensor unfolded per row.
+
+function imageFromTensors(bt, signed)
+ local gap = 1
+ local tgap = -1
+ local width = 0
+ local height = gap
+
+ for _, t in pairs(bt) do
+ -- print(t:size())
+ local d = t:dim()
+ local h, w = t:size(d - 1), t:size(d)
+ local n = t:nElement() / (w * h)
+ width = math.max(width, gap + n * (gap + w))
+ height = height + gap + tgap + gap + h
+ end
+
+ local e = torch.Tensor(3, height, width):fill(1.0)
+ local y0 = 1 + gap
+
+ for _, t in pairs(bt) do
+ local d = t:dim()
+ local h, w = t:size(d - 1), t:size(d)
+ local n = t:nElement() / (w * h)
+ local z = t:norm() / math.sqrt(t:nElement())
+
+ local x0 = 1 + gap + math.floor( (width - n * (w + gap)) /2 )
+ local u = torch.Tensor(t:size()):copy(t):resize(n, h, w)
+ for m = 1, n do
+
+ for c = 1, 3 do
+ for y = 0, h+1 do
+ e[c][y0 + y - 1][x0 - 1] = 0.0
+ e[c][y0 + y - 1][x0 + w ] = 0.0
+ end
+ for x = 0, w+1 do
+ e[c][y0 - 1][x0 + x - 1] = 0.0
+ e[c][y0 + h ][x0 + x - 1] = 0.0
+ end
+ end
+
+ for y = 1, h do
+ for x = 1, w do
+ local v = u[m][y][x] / z
+ local r, g, b
+ if signed then
+ if v < -1 then
+ r, g, b = 0.0, 0.0, 1.0
+ elseif v > 1 then
+ r, g, b = 1.0, 0.0, 0.0
+ elseif v >= 0 then
+ r, g, b = 1.0, 1.0 - v, 1.0 - v
+ else
+ r, g, b = 1.0 + v, 1.0 + v, 1.0
+ end
+ else
+ if v <= 0 then
+ r, g, b = 1.0, 1.0, 1.0
+ elseif v > 1 then
+ r, g, b = 0.0, 0.0, 0.0
+ else
+ r, g, b = 1.0 - v, 1.0 - v, 1.0 - v
+ end
+ end
+ e[1][y0 + y - 1][x0 + x - 1] = r
+ e[2][y0 + y - 1][x0 + x - 1] = g
+ e[3][y0 + y - 1][x0 + x - 1] = b
+ end
+ end
+ x0 = x0 + w + gap
+ end
+ y0 = y0 + h + gap + tgap + gap
+ end
+
+ return e
+end
+
+function collectAllOutputs(model, collection, which)
+ if torch.type(model) == 'nn.Sequential' then
+ for i = 1, #model.modules do
+ collectAllOutputs(model.modules[i], collection, which)
+ end
+ elseif not which or which[torch.type(model)] then
+ local t = torch.type(model.output)
+ if t == 'torch.FloatTensor' or t == 'torch.CudaTensor' then
+ collection.nb = collection.nb + 1
+ collection.outputs[collection.nb] = model.output
+ end
+ end
+end
+
+function saveInternalsImage(model, data, n)
+ -- Explicitely copy to keep input as a mynn.FastTensor
+ local input = mynn.FastTensor(1, 2, data.height, data.width)
+ input:copy(data.input:narrow(1, n, 1))
+
+ local output = model:forward(input)
+
+ local collection = {}
+ collection.outputs = {}
+ collection.nb = 1
+ collection.outputs[collection.nb] = input
+
+ local which = {}
+ which['nn.ReLU'] = true
+ collectAllOutputs(model, collection, which)
+
+ if collection.outputs[collection.nb] ~= model.output then
+ collection.nb = collection.nb + 1
+ collection.outputs[collection.nb] = model.output
+ end
+
+ local fileName = string.format('%s/internals_%s_%06d.png',
+ opt.resultDir,
+ data.name, n)
+
+ logString('Saving ' .. fileName .. '\n')
+ image.save(fileName, imageFromTensors(collection.outputs))
+end
+
+----------------------------------------------------------------------
+
+function saveResultImage(model, data, prefix, nbMax, highlight)
+ local l2criterion = nn.MSECriterion()
+
+ if useGPU then
+ logString('Moving the criterion to the GPU.\n')
+ l2criterion:cuda()
+ end
+
+ local prefix = prefix or 'result'
+ local result = torch.Tensor(data.height * 4 + 5, data.width + 2)
+ local input = mynn.FastTensor(1, 2, data.height, data.width)
+ local target = mynn.FastTensor(1, 1, data.height, data.width)
+
+ local nbMax = nbMax or 50
+
+ local nb = math.min(nbMax, data.nbSamples)
+
+ model:evaluate()
+
+ logString(string.format('Write %d result images `%s\' for set `%s\' in %s.\n',
+ nb, prefix, data.name,
+ opt.resultDir))
+
+ for n = 1, nb do
+
+ -- Explicitely copy to keep input as a mynn.FastTensor
+ input:copy(data.input:narrow(1, n, 1))
+ target:copy(data.target:narrow(1, n, 1))
+
+ local output = model:forward(input)
+
+ local loss = l2criterion:forward(output, target)
+
+ result:fill(1.0)
+
+ if highlight then
+ for i = 1, data.height do
+ for j = 1, data.width do
+ local v = data.input[n][1][i][j]
+ result[1 + i + 0 * (data.height + 1)][1 + j] = data.input[n][2][i][j]
+ result[1 + i + 1 * (data.height + 1)][1 + j] = v
+ local a = data.target[n][1][i][j]
+ local b = output[1][1][i][j]
+ result[1 + i + 2 * (data.height + 1)][1 + j] =
+ a * math.min(1, 0.1 + 2.0 * math.abs(a - v))
+ result[1 + i + 3 * (data.height + 1)][1 + j] =
+ b * math.min(1, 0.1 + 2.0 * math.abs(b - v))
+ end
+ end
+ else
+ for i = 1, data.height do
+ for j = 1, data.width do
+ result[1 + i + 0 * (data.height + 1)][1 + j] = data.input[n][2][i][j]
+ result[1 + i + 1 * (data.height + 1)][1 + j] = data.input[n][1][i][j]
+ result[1 + i + 2 * (data.height + 1)][1 + j] = data.target[n][1][i][j]
+ result[1 + i + 3 * (data.height + 1)][1 + j] = output[1][1][i][j]
+ end
+ end
+ end
+
+ result:mul(-1.0):add(1.0)
+
+ local fileName = string.format('%s/%s_%s_%06d.png',
+ opt.resultDir,
+ prefix,
+ data.name, n)
+
+ logString(string.format('LOSS_ON_SAMPLE %f %s\n', loss, fileName))
+
+ image.save(fileName, result)
+ end
+end
+
+----------------------------------------------------------------------
+
+function createTower(filterSize, nbChannels, nbBlocks)
+ local tower = mynn.Sequential()
+
+ for b = 1, nbBlocks do
+ local block = mynn.Sequential()
+
+ block:add(mynn.SpatialConvolution(nbChannels,
+ nbChannels,
+ filterSize, filterSize,
+ 1, 1,
+ (filterSize - 1) / 2, (filterSize - 1) / 2))
+ block:add(mynn.SpatialBatchNormalization(nbChannels))
+ block:add(mynn.ReLU(true))
+
+ block:add(mynn.SpatialConvolution(nbChannels,
+ nbChannels,
+ filterSize, filterSize,
+ 1, 1,
+ (filterSize - 1) / 2, (filterSize - 1) / 2))
+
+ local parallel = mynn.ConcatTable()
+ parallel:add(block):add(mynn.Identity())
+
+ tower:add(parallel):add(mynn.CAddTable(true))
+
+ tower:add(mynn.SpatialBatchNormalization(nbChannels))
+ tower:add(mynn.ReLU(true))
+ end
+
+ return tower
+end
+
+function createModel(filterSize, nbChannels, nbBlocks)
+ local model = mynn.Sequential()
+
+ model:add(mynn.SpatialConvolution(2,
+ nbChannels,
+ filterSize, filterSize,
+ 1, 1,
+ (filterSize - 1) / 2, (filterSize - 1) / 2))
+
+ model:add(mynn.SpatialBatchNormalization(nbChannels))
+ model:add(mynn.ReLU(true))
+
+ local towerCode = createTower(filterSize, nbChannels, nbBlocks)
+ local towerDecode = createTower(filterSize, nbChannels, nbBlocks)
+
+ model:add(towerCode)
+ model:add(towerDecode)
+
+ -- Decode to a single channel, which is the final image
+ model:add(mynn.SpatialConvolution(nbChannels,
+ 1,
+ filterSize, filterSize,
+ 1, 1,
+ (filterSize - 1) / 2, (filterSize - 1) / 2))
+
+ return model
+end
+
+----------------------------------------------------------------------
+
+function fillBatch(data, first, nb, batch, permutation)
+ for k = 1, nb do
+ local i
+ if permutation then
+ i = permutation[first + k - 1]
+ else
+ i = first + k - 1
+ end
+ batch.input[k] = data.input[i]
+ batch.target[k] = data.target[i]
+ end
+end
+
+function trainModel(model,
+ trainData, validationData, nbEpochs, learningRate,
+ learningStateFile)
+
+ local l2criterion = nn.MSECriterion()
+ local batchSize = config.batchSize
+
+ if useGPU then
+ logString('Moving the criterion to the GPU.\n')
+ l2criterion:cuda()
+ end
+
+ local batch = {}
+ batch.input = mynn.FastTensor(batchSize, 2, trainData.height, trainData.width)
+ batch.target = mynn.FastTensor(batchSize, 1, trainData.height, trainData.width)
+
+ local startingEpoch = 1
+
+ if model.epoch then
+ startingEpoch = model.epoch + 1
+ end
+
+ if model.RNGState then
+ torch.setRNGState(model.RNGState)
+ end
+
+ logString('Starting training.\n')
+
+ local parameters, gradParameters = model:getParameters()
+ logString(string.format('model has %d parameters.\n', parameters:storage():size(1)))
+
+ local averageTrainLoss, averageValidationLoss
+ local trainTime, validationTime
+
+ local sgdState = {
+ learningRate = config.learningRate,
+ momentum = config.momentum,
+ learningRateDecay = 0
+ }
+
+ for e = startingEpoch, nbEpochs do
+
+ model:training()
+
+ local permutation = torch.randperm(trainData.nbSamples)
+
+ local accLoss = 0.0
+ local nbBatches = 0
+ local startTime = sys.clock()
+
+ for b = 1, trainData.nbSamples, batchSize do
+
+ fillBatch(trainData, b, batchSize, batch, permutation)
+
+ local opfunc = function(x)
+ -- Surprisingly copy() needs this check
+ if x ~= parameters then
+ parameters:copy(x)
+ end
+
+ local output = model:forward(batch.input)
+ local loss = l2criterion:forward(output, batch.target)
+
+ local dLossdOutput = l2criterion:backward(output, batch.target)
+ gradParameters:zero()
+ model:backward(batch.input, dLossdOutput)
+
+ accLoss = accLoss + loss
+ nbBatches = nbBatches + 1
+
+ return loss, gradParameters
+ end
+
+ optim.sgd(opfunc, parameters, sgdState)
+
+ end
+
+ trainTime = sys.clock() - startTime
+ averageTrainLoss = accLoss / nbBatches
+
+ ----------------------------------------------------------------------
+ -- Validation losses
+ do
+ model:evaluate()
+
+ local accLoss = 0.0
+ local nbBatches = 0
+ local startTime = sys.clock()
+
+ for b = 1, validationData.nbSamples, batchSize do
+ fillBatch(trainData, b, batchSize, batch)
+ local output = model:forward(batch.input)
+ accLoss = accLoss + l2criterion:forward(output, batch.target)
+ nbBatches = nbBatches + 1
+ end
+
+ validationTime = sys.clock() - startTime
+ averageValidationLoss = accLoss / nbBatches;
+ end
+
+ logString(string.format('Epoch train %0.2fs (%0.2fms / sample), validation %0.2fs (%0.2fms / sample).\n',
+ trainTime,
+ 1000 * trainTime / trainData.nbSamples,
+ validationTime,
+ 1000 * validationTime / validationData.nbSamples))
+
+ logString(string.format('LOSS %d %f %f\n', e, averageTrainLoss, averageValidationLoss),
+ colors.green)
+
+ ----------------------------------------------------------------------
+ -- Save a persistent state so that we can restart from there
+
+ if learningStateFile then
+ model.RNGState = torch.getRNGState()
+ model.epoch = e
+ model:clearState()
+ logString('Writing ' .. learningStateFile .. '.\n')
+ torch.save(learningStateFile, model)
+ end
+
+ ----------------------------------------------------------------------
+ -- Save a duplicate of the persistent state from time to time
+
+ if opt.resultFreq > 0 and e%opt.resultFreq == 0 then
+ torch.save(string.format('%s/epoch_%05d_model', opt.resultDir, e), model)
+ saveResultImage(model, trainData)
+ saveResultImage(model, validationData)
+ end
+
+ end
+
+end
+
+function createAndTrainModel(trainData, validationData)
+
+ local model
+
+ local learningStateFile = opt.learningStateFile
+
+ if learningStateFile == '' then
+ learningStateFile = opt.resultDir .. '/learning.state'
+ end
+
+ local gotlearningStateFile
+
+ logString('Using the learning state file ' .. learningStateFile .. '\n')
+
+ if pcall(function () model = torch.load(learningStateFile) end) then
+
+ gotlearningStateFile = true
+
+ else
+
+ model = createModel(config.filterSize, config.nbChannels, config.nbBlocks)
+
+ if useGPU then
+ logString('Moving the model to the GPU.\n')
+ model:cuda()
+ end
+
+ end
+
+ logString(tostring(model) .. '\n')
+
+ if gotlearningStateFile then
+ logString(string.format('Found a learning state with %d epochs finished.\n', model.epoch),
+ colors.red)
+ end
+
+ if opt.exampleInternals > 0 then
+ saveInternalsImage(model, validationData, opt.exampleInternals)
+ os.exit(0)
+ end
+
+ trainModel(model,
+ trainData, validationData,
+ config.nbEpochs, config.learningRate,
+ learningStateFile)
+
+ return model
+
+end
+
+for i, j in pairs(config) do
+ logString('config ' .. i .. ' = \'' .. j ..'\'\n')
+end
+
+local trainData = loadData(1, config.nbTrainSamples, 'train')
+local validationData = loadData(config.nbTrainSamples + 1, config.nbValidationSamples, 'validation')
+local testData = loadData(config.nbTrainSamples + config.nbValidationSamples + 1, config.nbTestSamples, 'test')
+
+local model = createAndTrainModel(trainData, validationData)
+
+saveResultImage(model, trainData)
+saveResultImage(model, validationData)
+saveResultImage(model, testData, nil, testData.nbSamples)