Removed the signature.
[profiler-torch.git] / test-profiler.lua
index 44bbee1..92f186e 100755 (executable)
@@ -1,18 +1,50 @@
 #!/usr/bin/env luajit
 
+--[[
+
+   Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
+   Written by Francois Fleuret <francois.fleuret@idiap.ch>
+
+   This file is free software: you can redistribute it and/or modify
+   it under the terms of the GNU General Public License version 3 as
+   published by the Free Software Foundation.
+
+   It is distributed in the hope that it will be useful, but WITHOUT
+   ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+   or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public
+   License for more details.
+
+   You should have received a copy of the GNU General Public License
+   along with this file.  If not, see <http://www.gnu.org/licenses/>.
+
+]]--
+
 require 'torch'
 require 'nn'
 
 require 'profiler'
 
+-- Create a model
+
+local w, h, fs = 50, 50, 3
+local nhu =  (w - fs + 1) * (h - fs + 1)
+
 local model = nn.Sequential()
-model:add(nn.Linear(1000, 1000))
-model:add(nn.ReLU())
-model:add(nn.Linear(1000, 100))
+   :add(nn.Sequential()
+           :add(nn.SpatialConvolution(1, 1, fs, fs))
+           :add(nn.Reshape(nhu))
+           :add(nn.Linear(nhu, 1000))
+           :add(nn.ReLU())
+       )
+   :add(nn.Linear(1000, 100))
+
+-- Decorate it for profiling
+
+profiler.decorate(model)
 
-profiler.decor(model)
+-- Create the data and criterion
 
-local input = torch.Tensor(1000, 1000)
+local input = torch.Tensor(1000, 1, h, w)
 local target = torch.Tensor(input:size(1), 100)
 local criterion = nn.MSECriterion()
 
@@ -20,8 +52,11 @@ local nbSamples = 0
 local modelTime = 0
 local dataTime = 0
 
+-- Loop five times through the data forward and backward
+
 for k = 1, 5 do
    local t1 = sys.clock()
+
    input:uniform(-1, 1)
    target:uniform()
 
@@ -40,8 +75,12 @@ for k = 1, 5 do
    nbSamples = nbSamples + input:size(1)
 end
 
+-- Print the accumulated timings
+
+print()
+-- profiler.color = false
 profiler.print(model, nbSamples)
+-- profiler.print(model)
 
-print('----------------------------------------------------------------------')
 print(string.format('Total model time %.02fs', modelTime))
 print(string.format('Total data time %.02fs', dataTime))