mynn = {}
--- By default, mynn returns the entries from nn
+-- To deal elegantly with CPU/GPU
local mt = {}
function mt.__index(table, key)
- return nn[key]
+ return (cudnn and cudnn[key]) or (cunn and cunn[key]) or nn[key]
end
setmetatable(mynn, mt)
require 'cutorch'
require 'cunn'
require 'cudnn'
+
mynn.FastTensor = torch.CudaTensor
- mynn.SpatialConvolution = cudnn.SpatialConvolution
+
+ if cudnn then
+ cudnn.benchmark = true
+ cudnn.fastest = true
+ end
end
----------------------------------------------------------------------