From: Francois Fleuret Date: Sat, 3 Dec 2016 16:04:34 +0000 (+0100) Subject: Initial commit X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=21c3bd2eb990e3fa58aa36b0e8fcd8901de5569c;p=profiler-torch.git Initial commit --- 21c3bd2eb990e3fa58aa36b0e8fcd8901de5569c diff --git a/profiler.lua b/profiler.lua new file mode 100755 index 0000000..91b0915 --- /dev/null +++ b/profiler.lua @@ -0,0 +1,50 @@ + +require 'torch' +require 'nn' +require 'sys' + +profiler = {} + +function profiler.decor(model, functionsToDecorate) + + local functionsToDecorate = functionsToDecorate or + { + 'updateOutput', + 'backward' + } + + for _, name in pairs(functionsToDecorate) do + model.orig = model.orig or {} + model.timings = 0 + + if model[name] and not model.orig[name] then + model.orig[name] = model[name] + model[name] = function(self, ...) + local startTime = sys.clock() + local result = { self.orig[name](self, unpack({...})) } + local endTime = sys.clock() + self.timings = self.timings + endTime - startTime + return unpack(result) + end + end + + end + + if torch.isTypeOf(model, nn.Container) then + for _, m in ipairs(model.modules) do + profiler.decor(m, functionsToDecorate) + end + end + +end + +function profiler.print(model) + print('----------------------------------------------------------------------') + print(model) + print(string.format('TIMING %.02fs', model.timings)) + if torch.isTypeOf(model, nn.Container) then + model:applyToModules(profiler.print) + end +end + +return profiler diff --git a/test-profiler.lua b/test-profiler.lua new file mode 100755 index 0000000..b394a33 --- /dev/null +++ b/test-profiler.lua @@ -0,0 +1,25 @@ +#!/usr/bin/env luajit + +require 'torch' +require 'nn' + +require 'profiler' + +local model = nn.Sequential() +model:add(nn.Linear(1000, 1000)) +model:add(nn.ReLU()) +model:add(nn.Linear(1000, 100)) + +profiler.decor(model) + +for k = 1, 10 do + local input = torch.Tensor(1000, 1000):uniform(-1, 1) + local target = torch.Tensor(input:size(1), 100):uniform() + local criterion = nn.MSECriterion() + local output = model:forward(input) + local loss = criterion:forward(output, target) + local dloss = criterion:backward(output, target) + model:backward(input, dloss) +end + +profiler.print(model)