Use cudnn more efficiently.
[dyncnn.git] / dyncnn.lua
index 839431a..e104386 100755 (executable)
@@ -126,10 +126,10 @@ torch.manualSeed(opt.seed)
 
 mynn = {}
 
--- By default, mynn returns the entries from nn
+-- To deal elegantly with CPU/GPU
 local mt = {}
 function mt.__index(table, key)
-   return nn[key]
+   return (cudnn and cudnn[key]) or (cunn and cunn[key]) or nn[key]
 end
 setmetatable(mynn, mt)
 
@@ -144,8 +144,13 @@ if useGPU then
    require 'cutorch'
    require 'cunn'
    require 'cudnn'
+
    mynn.FastTensor = torch.CudaTensor
-   mynn.SpatialConvolution = cudnn.SpatialConvolution
+
+   if cudnn then
+      cudnn.benchmark = true
+      cudnn.fastest = true
+   end
 end
 
 ----------------------------------------------------------------------