So, back to decorating the classes and not the objects so that torch.save() does...
[profiler-torch.git] / test-profiler.lua
index a78c944..18677ec 100755 (executable)
@@ -39,9 +39,14 @@ require 'profiler'
 
 -- Create a model
 
+local w, h, fs = 50, 50, 3
+local nhu =  (w - fs + 1) * (h - fs + 1)
+
 local model = nn.Sequential()
    :add(nn.Sequential()
-           :add(nn.Linear(1000, 1000))
+           :add(nn.SpatialConvolution(1, 1, fs, fs))
+           :add(nn.Reshape(nhu))
+           :add(nn.Linear(nhu, 1000))
            :add(nn.ReLU())
        )
    :add(nn.Linear(1000, 100))
@@ -55,7 +60,7 @@ torch.save('model.t7', model)
 
 -- Create the data and criterion
 
-local input = torch.Tensor(1000, 1000)
+local input = torch.Tensor(1000, 1, h, w)
 local target = torch.Tensor(input:size(1), 100)
 local criterion = nn.MSECriterion()
 
@@ -88,7 +93,7 @@ end
 
 -- Print the accumulated timings
 
-profiler.print(model, nbSamples)
+profiler.print(model, nbSamples, modelTime)
 -- profiler.print(model)
 
 print()