X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=dyncnn.git;a=blobdiff_plain;f=dyncnn.lua;h=e1043868c3c432849059f527c145131b5041a336;hp=839431ab40ccdd33d277c2927295ef40016c0ef2;hb=be0c7d53f21ce96c70e7c13ef0ba2c9eca10ca23;hpb=51142591389a8119a337813899ae26d682e9f13d diff --git a/dyncnn.lua b/dyncnn.lua index 839431a..e104386 100755 --- a/dyncnn.lua +++ b/dyncnn.lua @@ -126,10 +126,10 @@ torch.manualSeed(opt.seed) mynn = {} --- By default, mynn returns the entries from nn +-- To deal elegantly with CPU/GPU local mt = {} function mt.__index(table, key) - return nn[key] + return (cudnn and cudnn[key]) or (cunn and cunn[key]) or nn[key] end setmetatable(mynn, mt) @@ -144,8 +144,13 @@ if useGPU then require 'cutorch' require 'cunn' require 'cudnn' + mynn.FastTensor = torch.CudaTensor - mynn.SpatialConvolution = cudnn.SpatialConvolution + + if cudnn then + cudnn.benchmark = true + cudnn.fastest = true + end end ----------------------------------------------------------------------