8 function profiler.decor(model, functionsToDecorate)
10 local functionsToDecorate = functionsToDecorate or
16 for _, name in pairs(functionsToDecorate) do
17 model.orig = model.orig or {}
20 if model[name] and not model.orig[name] then
21 model.orig[name] = model[name]
22 model[name] = function(self, ...)
23 local startTime = sys.clock()
24 local result = { self.orig[name](self, unpack({...})) }
25 local endTime = sys.clock()
26 self.timings = self.timings + endTime - startTime
33 if torch.isTypeOf(model, nn.Container) then
34 for _, m in ipairs(model.modules) do
35 profiler.decor(m, functionsToDecorate)
41 function profiler.print(model)
42 print('----------------------------------------------------------------------')
44 print(string.format('TIMING %.02fs', model.timings))
45 if torch.isTypeOf(model, nn.Container) then
46 model:applyToModules(profiler.print)