8 function checkGrad(model, criterion, input, target)
9 local params, gradParams = model:getParameters()
13 local output = model:forward(input)
14 local loss = criterion:forward(output, target)
15 local gradOutput = criterion:backward(output, target)
17 model:backward(input, gradOutput)
18 local analyticalGradParam = gradParams:clone()
20 for i = 1, params:size(1) do
23 params[i] = x - epsilon
24 local output0 = model:forward(input)
25 local loss0 = criterion:forward(output0, target)
27 params[i] = x + epsilon
28 local output1 = model:forward(input)
29 local loss1 = criterion:forward(output1, target)
33 local ana = analyticalGradParam[i]
34 local num = (loss1 - loss0) / (2 * epsilon)
35 local err = torch.abs(num - ana) / torch.abs(num)
38 err .. ' checkGrad ' .. i
39 .. ' analytical ' .. ana
40 .. ' numerical ' .. num
46 function printTensorTable(t)
47 if torch.type(t) == 'table' then
48 for i, t in pairs(t) do
49 print('-- ELEMENT [' .. i .. '] --')
-- Global torch configuration for this test script.
-- The thread-count override is left disabled. Note that no file-scope `params`
-- table is visible in this listing (checkGrad's `params` is a local), so
-- re-enabling the line as written would need one to be defined first.
57 -- torch.setnumthreads(params.nbThreads)
-- Use double precision for newly created tensors (presumably to keep the
-- finite-difference gradient check below numerically meaningful).
58 torch.setdefaulttensortype('torch.DoubleTensor')
64 -- input --> a --> b ---> d ----+ g --> output
77 ----------------------------------------------------------------------
-- Random test input: a 3x10 tensor filled via :uniform(). No RNG seed is set
-- in the visible code, so the values differ between runs.
-- NOTE(review): `model` is used below but is not constructed anywhere in this
-- listing; confirm its definition was not lost.
93 input = torch.Tensor(3, 10):uniform()
95 print('******************************************************************')
96 print('** updateOutput **************************************************')
97 print('******************************************************************')
-- Forward pass (stored in the global `output`). The result is cloned,
-- presumably so that later forward calls on `model` (e.g. inside checkGrad)
-- cannot overwrite the values kept here.
99 output = model:updateOutput(input):clone()
101 printTensorTable(output)
103 print('******************************************************************')
104 print('** updateGradInput ***********************************************')
105 print('******************************************************************')
-- Backward pass with respect to the input only. The forward `output` is passed
-- as the gradOutput argument; it appears to serve simply as a tensor of
-- matching shape to drive the backward pass (confirm intent).
-- Unlike `output`, the result is not cloned, so a later backward call on
-- `model` may overwrite it. It is printed immediately below, before checkGrad runs.
107 gradInput = model:updateGradInput(input, output)
109 printTensorTable(gradInput)
111 print('******************************************************************')
112 print('** checkGrad *****************************************************')
113 print('******************************************************************')
-- Compare the analytical parameter gradients against central finite
-- differences, using the model's own forward output as the MSE target.
-- NOTE(review): with target == output (and assuming the model is
-- deterministic), the loss is already at its minimum. Both the analytical and
-- the numerical gradient are then ~0. checkGrad's relative error divides by
-- |num|, which can yield nan/inf or meaningless ratios. Consider a different
-- target (e.g. a perturbed copy of output) if a real check is intended.
117 checkGrad(model, nn.MSECriterion(), input, output)