From b8c7166b9123735e8226d34b717d3cbc2dc1fa02 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Tue, 6 Dec 2016 09:07:06 +0100 Subject: [PATCH] So, back to decorating the classes and not the objects so that torch.save() does not complain with SpatialConvolution. Added the possibility to pass the total time to profiler.print() so that the fraction of time used by the different functions can be displayed. --- profiler.lua | 32 ++++++++++++++++++++++---------- test-profiler.lua | 11 ++++++++--- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/profiler.lua b/profiler.lua index 34e180b..4e45787 100644 --- a/profiler.lua +++ b/profiler.lua @@ -50,9 +50,18 @@ function profiler.decorate(model, functionsToDecorate) local nameOrig = name .. '__orig' - if model[name] and not model[nameOrig] then - model[nameOrig] = model[name] - model[name] = function(self, ...) + -- We decorate the class and not the object, otherwise we cannot + -- save models anymore. + + if rawget(model, name) then + error('We decorate the class, not the objects, and there is a ' .. name .. ' in ' .. model) + end + + local toDecorate = getmetatable(model) + + if toDecorate[name] and not toDecorate[nameOrig] then + toDecorate[nameOrig] = toDecorate[name] + toDecorate[name] = function(self, ...) local startTime = sys.clock() local result = { self[nameOrig](self, unpack({...})) } local endTime = sys.clock() @@ -71,24 +80,27 @@ function profiler.decorate(model, functionsToDecorate) end -function profiler.print(model, nbSamples, indent) +function profiler.print(model, nbSamples, totalTime, indent) local indent = indent or '' print(string.format('%s* %s', indent, model.__typename)) for l, t in pairs(model.accTime) do - local s + local s = string.format('%s %s %.02fs', indent, l, t) + if totalTime then + s = s .. string.format(' [%.02f%%]', 100 * t / totalTime) + end if nbSamples then - s = string.format(' (%.01fmus/sample)', 1e6 * t / nbSamples) - else - s = '' + s = s .. string.format(' (%.01fmus/sample)', 1e6 * t / nbSamples) end - print(string.format('%s %s %.02fs%s', indent, l, t, s)) + print(s) end + print() + if torch.isTypeOf(model, nn.Container) then for _, m in ipairs(model.modules) do - profiler.print(m, nbSamples, indent .. ' ') + profiler.print(m, nbSamples, totalTime, indent .. ' ') end end end diff --git a/test-profiler.lua b/test-profiler.lua index a78c944..18677ec 100755 --- a/test-profiler.lua +++ b/test-profiler.lua @@ -39,9 +39,14 @@ require 'profiler' -- Create a model +local w, h, fs = 50, 50, 3 +local nhu = (w - fs + 1) * (h - fs + 1) + local model = nn.Sequential() :add(nn.Sequential() - :add(nn.Linear(1000, 1000)) + :add(nn.SpatialConvolution(1, 1, fs, fs)) + :add(nn.Reshape(nhu)) + :add(nn.Linear(nhu, 1000)) :add(nn.ReLU()) ) :add(nn.Linear(1000, 100)) @@ -55,7 +60,7 @@ torch.save('model.t7', model) -- Create the data and criterion -local input = torch.Tensor(1000, 1000) +local input = torch.Tensor(1000, 1, h, w) local target = torch.Tensor(input:size(1), 100) local criterion = nn.MSECriterion() @@ -88,7 +93,7 @@ end -- Print the accumulated timings -profiler.print(model, nbSamples) +profiler.print(model, nbSamples, modelTime) -- profiler.print(model) print() -- 2.39.5