X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=profiler-torch.git;a=blobdiff_plain;f=test-profiler.lua;h=92f186e5cd8be1170916c8ac60b4050353f65bf9;hp=44bbee1656ee13d1b3634ef52c71c5afdf22fbf6;hb=HEAD;hpb=faf424f71ac259f3a7e676113cfe4892084b1c93 diff --git a/test-profiler.lua b/test-profiler.lua index 44bbee1..92f186e 100755 --- a/test-profiler.lua +++ b/test-profiler.lua @@ -1,18 +1,50 @@ #!/usr/bin/env luajit +--[[ + + Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ + Written by Francois Fleuret + + This file is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License version 3 as + published by the Free Software Foundation. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY + or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public + License for more details. + + You should have received a copy of the GNU General Public License + along with this file. If not, see . + +]]-- + require 'torch' require 'nn' require 'profiler' +-- Create a model + +local w, h, fs = 50, 50, 3 +local nhu = (w - fs + 1) * (h - fs + 1) + local model = nn.Sequential() -model:add(nn.Linear(1000, 1000)) -model:add(nn.ReLU()) -model:add(nn.Linear(1000, 100)) + :add(nn.Sequential() + :add(nn.SpatialConvolution(1, 1, fs, fs)) + :add(nn.Reshape(nhu)) + :add(nn.Linear(nhu, 1000)) + :add(nn.ReLU()) + ) + :add(nn.Linear(1000, 100)) + +-- Decorate it for profiling + +profiler.decorate(model) -profiler.decor(model) +-- Create the data and criterion -local input = torch.Tensor(1000, 1000) +local input = torch.Tensor(1000, 1, h, w) local target = torch.Tensor(input:size(1), 100) local criterion = nn.MSECriterion() @@ -20,8 +52,11 @@ local nbSamples = 0 local modelTime = 0 local dataTime = 0 +-- Loop five times through the data forward and backward + for k = 1, 5 do local t1 = sys.clock() + input:uniform(-1, 1) target:uniform() @@ -40,8 +75,12 @@ for k = 1, 5 do nbSamples = nbSamples + input:size(1) end +-- Print the accumulated timings + +print() +-- profiler.color = false profiler.print(model, nbSamples) +-- profiler.print(model) -print('----------------------------------------------------------------------') print(string.format('Total model time %.02fs', modelTime)) print(string.format('Total data time %.02fs', dataTime))