Prints only the max error.
authorFrancois Fleuret <francois@fleuret.org>
Thu, 12 Jan 2017 21:19:02 +0000 (22:19 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Thu, 12 Jan 2017 21:19:02 +0000 (22:19 +0100)
test-dagnn.lua

index a41d880..3dea310 100755 (executable)
@@ -39,6 +39,8 @@ function checkGrad(model, criterion, input, target)
    model:backward(input, gradOutput)
    local analyticalGradParam = gradParams:clone()
 
+   local err = 0
+
    for i = 1, params:size(1) do
       local x = params[i]
 
@@ -54,23 +56,13 @@ function checkGrad(model, criterion, input, target)
 
       local ana = analyticalGradParam[i]
       local num = (loss1 - loss0) / (2 * epsilon)
-      local err
 
-      if num == ana then
-         err = 0
-      else
-         err = torch.abs(num - ana) / torch.abs(num)
+      if num ~= ana then
+         err = math.max(err, torch.abs(num - ana) / torch.abs(num))
       end
-
-      print(
-         'CHECK '
-            .. err
-            .. ' checkGrad ' .. i
-            .. ' analytical ' .. ana
-            .. ' numerical ' .. num
-      )
    end
 
+   return err
 end
 
 function printTensorTable(t)
@@ -115,4 +107,4 @@ local output = model:updateOutput(input):clone()
 
 output:uniform()
 
-checkGrad(model, nn.MSECriterion(), input, output)
+print('Error = ' .. checkGrad(model, nn.MSECriterion(), input, output))