mynn = {}
--- By default, mynn returns the entries from nn
+-- To deal elegantly with CPU/GPU
local mt = {}
function mt.__index(table, key)
- return nn[key]
+ return (cudnn and cudnn[key]) or (cunn and cunn[key]) or nn[key]
end
setmetatable(mynn, mt)
require 'cutorch'
require 'cunn'
require 'cudnn'
+
mynn.FastTensor = torch.CudaTensor
- mynn.SpatialConvolution = cudnn.SpatialConvolution
+
+ if cudnn then
+ cudnn.benchmark = true
+ cudnn.fastest = true
+ end
end
----------------------------------------------------------------------
local startTime = sys.clock()
for b = 1, validationData.nbSamples, batchSize do
- fillBatch(trainData, b, batchSize, batch)
+ fillBatch(validationData, b, batchSize, batch)
local output = model:forward(batch.input)
accLoss = accLoss + l2criterion:forward(output, batch.target)
nbBatches = nbBatches + 1