require 'nn'
require 'optim'
require 'image'
-require 'pl'
require 'img'
colors = sys.COLORS
function printfc(c, f, ...)
- printf(c .. string.format(f, unpack({...})) .. colors.black)
+ print(c .. string.format(f, unpack({...})) .. colors.black)
end
function logCommand(c)
end
----------------------------------------------------------------------
--- Environment and command line arguments
+-- Environment variables
local defaultNbThreads = 1
local defaultUseGPU = false
if os.getenv('TORCH_NB_THREADS') then
defaultNbThreads = os.getenv('TORCH_NB_THREADS')
print('Environment variable TORCH_NB_THREADS is set and equal to ' .. defaultNbThreads)
else
- print('Environment variable TORCH_NB_THREADS is not set')
+ print('Environment variable TORCH_NB_THREADS is not set, default is ' .. defaultNbThreads)
end
if os.getenv('TORCH_USE_GPU') then
defaultUseGPU = os.getenv('TORCH_USE_GPU') == 'yes'
print('Environment variable TORCH_USE_GPU is set and evaluated as ' .. tostring(defaultUseGPU))
else
- print('Environment variable TORCH_USE_GPU is not set.')
+ print('Environment variable TORCH_USE_GPU is not set, default is ' .. tostring(defaultUseGPU))
end
----------------------------------------------------------------------
+-- Command line arguments
local cmd = torch.CmdLine()
-cmd:text('')
cmd:text('General setup')
cmd:option('-seed', 1, 'initial random seed')
cmd:text('Log')
cmd:option('-resultFreq', 100, 'at which epoch frequency should we save result images')
-cmd:option('-exampleInternals', -1, 'should we save inner activation images')
+cmd:option('-exampleInternals', '', 'list of comma-separated indices for inner activation images')
cmd:option('-noLog', false, 'should we prevent logging')
cmd:option('-rundir', '', 'the directory for results')
+cmd:option('-deltaImages', false, 'should we highlight the difference in result images')
+
+cmd:text('')
+cmd:text('Network structure')
+
+cmd:option('-filterSize', 5, 'convolution filter size')
+cmd:option('-nbChannels', 16)
+cmd:option('-nbBlocks', 8)
cmd:text('')
cmd:text('Training')
-cmd:option('-nbEpochs', 1000, 'nb of epochs for the heavy setting')
+cmd:option('-nbEpochs', 2000, 'nb of epochs for the heavy setting')
cmd:option('-learningRate', 0.1, 'learning rate')
cmd:option('-batchSize', 128, 'size of the mini-batches')
-cmd:option('-filterSize', 5, 'convolution filter size')
cmd:option('-nbTrainSamples', 32768)
cmd:option('-nbValidationSamples', 1024)
cmd:option('-nbTestSamples', 1024)
cmd:option('-dataDir', './data/10p-mg', 'data directory')
-cmd:text('')
-cmd:text('Network structure')
-
-cmd:option('-nbChannels', 16)
-cmd:option('-nbBlocks', 8)
-
----------------------------------------------------------------------
-- Log and stuff
----------------------------------------------------------------------
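+-- Return a copy of a in which the pixels differing from b keep their
+-- value and the others are strongly darkened. Only active when the
+-- -deltaImages option is set, otherwise a is returned unchanged.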
+function highlightImage(a, b)
+ if params.deltaImages then
+ local h = torch.csub(a, b):abs()
+ h:div(h:max()):mul(0.9):add(0.1)
+ return torch.cmul(a, h)
+ else
+ return a
+ end
+end
+
function saveResultImage(model, data, nbMax)
local criterion = nn.MSECriterion()
-- We use our magical img.lua to create the result images
- local comp = {
- {
- { pad = 1, data.input[n][1] },
- { pad = 1, data.input[n][2] },
- { pad = 1, data.target[n][1] },
- { pad = 1, output[1][1] },
- }
- }
+ local comp
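+ -- Compose a vertical strip with the two input frames, then the target
+ -- and the model output (differences from the first frame highlighted
+ -- when -deltaImages is set)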
- --[[
- local comp = {
+ comp = {
{
vertical = true,
{ pad = 1, data.input[n][1] },
- { pad = 1, data.input[n][2] }
- },
- torch.Tensor(4, 4):fill(1.0),
- {
- vertical = true,
- { pad = 1, data.target[n][1] },
- { pad = 1, output[1][1] },
- { pad = 1, torch.csub(data.target[n][1], output[1][1]):abs() }
+ { pad = 1, data.input[n][2] },
+ { pad = 1, highlightImage(data.target[n][1], data.input[n][1]) },
+ { pad = 1, highlightImage(output[1][1], data.input[n][1]) },
}
}
- ]]--
-local result = combineImages(1.0, comp)
+ local result = combineImages(1.0, comp)
-result:mul(-1.0):add(1.0)
+ result:mul(-1.0):add(1.0)
-local fileName = string.format('result_%s_%06d.png', data.name, n)
-image.save(params.rundir .. '/' .. fileName, result)
-lossFile:write(string.format('%f %s\n', loss, fileName))
-end
+ local fileName = string.format('result_%s_%06d.png', data.name, n)
+ image.save(params.rundir .. '/' .. fileName, result)
+ lossFile:write(string.format('%f %s\n', loss, fileName))
+ end
end
----------------------------------------------------------------------
end
end
-function trainModel(model, trainData, validationData)
+function trainModel(model, trainSet, validationSet)
local criterion = nn.MSECriterion()
local batchSize = params.batchSize
local batch = {}
- batch.input = mynn.FastTensor(batchSize, 2, trainData.height, trainData.width)
- batch.target = mynn.FastTensor(batchSize, 1, trainData.height, trainData.width)
+ batch.input = mynn.FastTensor(batchSize, 2, trainSet.height, trainSet.width)
+ batch.target = mynn.FastTensor(batchSize, 1, trainSet.height, trainSet.width)
local startingEpoch = 1
end
if model.RNGState then
+ printfc(colors.red, 'Using the RNG state from the loaded model.')
torch.setRNGState(model.RNGState)
end
model:training()
- local permutation = torch.randperm(trainData.nbSamples)
+ local permutation = torch.randperm(trainSet.nbSamples)
local accLoss = 0.0
local nbBatches = 0
local startTime = sys.clock()
- for b = 1, trainData.nbSamples, batchSize do
+ for b = 1, trainSet.nbSamples, batchSize do
- fillBatch(trainData, b, batch, permutation)
+ fillBatch(trainSet, b, batch, permutation)
local opfunc = function(x)
-- Surprisingly, copy() needs this check
local nbBatches = 0
local startTime = sys.clock()
- for b = 1, validationData.nbSamples, batchSize do
- fillBatch(validationData, b, batch)
+ for b = 1, validationSet.nbSamples, batchSize do
+ fillBatch(validationSet, b, batch)
local output = model:forward(batch.input)
accLoss = accLoss + criterion:forward(output, batch.target)
nbBatches = nbBatches + 1
averageValidationLoss = accLoss / nbBatches;
end
- printf('Epoch train %0.2fs (%0.2fms / sample), validation %0.2fs (%0.2fms / sample).',
- trainTime,
- 1000 * trainTime / trainData.nbSamples,
- validationTime,
- 1000 * validationTime / validationData.nbSamples)
- printfc(colors.green, 'LOSS %d %f %f', e, averageTrainLoss, averageValidationLoss)
+ ----------------------------------------------------------------------
+ printfc(colors.green,
+ 'epoch %d acc_train_loss %f validation_loss %f [train %.02fs total %.02fms / sample, validation %.02fs total %.02fms / sample]',
+ e,
+ averageTrainLoss,
+ averageValidationLoss,
+ trainTime,
+ 1000 * trainTime / trainSet.nbSamples,
+ validationTime,
+ 1000 * validationTime / validationSet.nbSamples
+ )
----------------------------------------------------------------------
-- Save a persistent state so that we can restart from there
if params.resultFreq > 0 and e%params.resultFreq == 0 then
torch.save(string.format('%s/model_%04d.t7', params.rundir, e), model)
- saveResultImage(model, trainData)
- saveResultImage(model, validationData)
+ saveResultImage(model, trainSet)
+ saveResultImage(model, validationSet)
end
end
end
-function createAndTrainModel(trainData, validationData)
+function createAndTrainModel(trainSet, validationSet)
-- Load the current training state, or create a new model from
-- scratch
if pcall(function () model = torch.load(params.rundir .. '/model_last.t7') end) then
printfc(colors.red,
- 'Found a learning state with %d epochs finished, starting from there.',
+ 'Found a model with %d epochs completed, starting from there.',
model.epoch)
- if params.exampleInternals > 0 then
- saveInternalsImage(model, validationData, params.exampleInternals)
+ if params.exampleInternals ~= '' then
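+ -- Save the images of the internal activations for each of the
+ -- requested sample indices, then quit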
+ for _, i in ipairs(string.split(params.exampleInternals, ',')) do
+ saveInternalsImage(model, validationSet, tonumber(i))
+ end
os.exit(0)
end
else
- model = createModel(trainData.width, trainData.height,
+ model = createModel(trainSet.width, trainSet.height,
params.filterSize, params.nbChannels,
params.nbBlocks)
end
- trainModel(model, trainData, validationData)
+ trainModel(model, trainSet, validationSet)
return model
logCommand(c)
end
-local trainData = loadData(1,
- params.nbTrainSamples, 'train')
+local trainSet = loadData(1,
+ params.nbTrainSamples, 'train')
-local validationData = loadData(params.nbTrainSamples + 1,
- params.nbValidationSamples, 'validation')
+local validationSet = loadData(params.nbTrainSamples + 1,
+ params.nbValidationSamples, 'validation')
-local model = createAndTrainModel(trainData, validationData)
+local model = createAndTrainModel(trainSet, validationSet)
----------------------------------------------------------------------
-- Test
-local testData = loadData(params.nbTrainSamples + params.nbValidationSamples + 1,
- params.nbTestSamples, 'test')
+local testSet = loadData(params.nbTrainSamples + params.nbValidationSamples + 1,
+ params.nbTestSamples, 'test')
if params.useGPU then
print('Moving the model and criterion to the GPU.')
model:cuda()
end
-saveResultImage(model, trainData)
-saveResultImage(model, validationData)
-saveResultImage(model, testData, 1024)
+saveResultImage(model, trainSet)
+saveResultImage(model, validationSet)
+saveResultImage(model, testSet, 1024)