8 function profiler.decor(model, functionsToDecorate)
10 local functionsToDecorate = functionsToDecorate or
16 for _, name in pairs(functionsToDecorate) do
19 local functionTable = model
21 if not rawget(functionTable, name) then
22 functionTable = getmetatable(model)
25 if functionTable[name] and not (functionTable.orig and functionTable.orig[name]) then
26 print('Profiler decoring ' .. functionTable.__typename .. '.' .. name)
27 functionTable.orig = functionTable.orig or {}
28 functionTable.orig[name] = functionTable[name]
29 functionTable[name] = function(self, ...)
30 local startTime = sys.clock()
31 local result = { self.orig[name](self, unpack({...})) }
32 local endTime = sys.clock()
33 self.timings = self.timings + endTime - startTime
40 if torch.isTypeOf(model, nn.Container) then
41 for _, m in ipairs(model.modules) do
42 profiler.decor(m, functionsToDecorate)
48 function profiler.print(model, nbSamples)
49 print('----------------------------------------------------------------------')
52 print(string.format('acc_time %.02fs (%.1ems/sample)', model.timings, 1000 * model.timings / nbSamples))
54 print(string.format('acc_time %.02fs', model.timings))
57 if torch.isTypeOf(model, nn.Container) then
58 for _, m in ipairs(model.modules) do
59 profiler.print(m, nbSamples)