X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=profiler-torch.git;a=blobdiff_plain;f=profiler.lua;fp=profiler.lua;h=3bbccaf88ba65e400b26f625ccb971d41bcbde94;hp=4e4578741c457c1af5aa0d7a99ceb7b2a4ce7dfc;hb=1bc832ac69797a2fabdb4f1dcf758cb415e4f215;hpb=e927faab65fb190dc01959236c07df46f3d28946 diff --git a/profiler.lua b/profiler.lua index 4e45787..3bbccaf 100644 --- a/profiler.lua +++ b/profiler.lua @@ -82,8 +82,15 @@ end function profiler.print(model, nbSamples, totalTime, indent) local indent = indent or '' + local hint - print(string.format('%s* %s', indent, model.__typename)) + if torch.isTypeOf(model, nn.Container) then + hint = ' ' + else + hint = '*' + end + + print(string.format('%s%s %s', indent, hint, model.__typename)) for l, t in pairs(model.accTime) do local s = string.format('%s %s %.02fs', indent, l, t)