   return self.output
end
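+-- Accumulates the gradients coming back from a node's successors: a single
+-- successor's tensor is reused as-is (no clone), several are summed into a
+-- clone of the first; returns nil when gradInputSucc is empty.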
+function DAG:computeGradInput(gradInputSucc)
+   local gi
+   if #gradInputSucc == 1 then
+      gi = gradInputSucc[1] -- we avoid a clone()
+   elseif #gradInputSucc > 1 then
+      for k = 1, #gradInputSucc do
+         if gi then
+            gi:add(gradInputSucc[k])
+         else
+            gi = gradInputSucc[k]:clone()
+         end
+      end
+   end
+   return gi
+end
+
function DAG:updateGradInput(input, gradOutput)
   self:putInOrder()
   self:nestApply(
      function(nnm, go) self.node[nnm].gradOutput = go end,
      self.outputModules, gradOutput
   )
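+   -- Record the inputs in the corresponding nodes; node.input is used below
+   -- when updateGradInput is called on each module.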
+   self:nestApply(
+      function(nnm, i) self.node[nnm].input = i end,
+      self.inputModules, input
+   )
+
   for _, node in pairs(self.node) do
      node.gradInputSucc = {}
   end
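+   -- Walk the nodes in reverse topological order, so that all successors of a
+   -- node have already deposited their gradients into its gradInputSucc.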
   for k = #self.sorted, 1, -1 do
      local nnm = self.sorted[k]
      local node = self.node[nnm]
-      local pred, succ, gradInputSucc = node.pred, node.succ, node.gradInputSucc
+      local pred, gradInputSucc = node.pred, node.gradInputSucc
      if #gradInputSucc > 0 then
-         -- We update nnm:gradInput
-         local gi
-         if #gradInputSucc == 1 then
-            gi = gradInputSucc[1] -- we avoid a clone()
-         elseif #gradInputSucc > 1 then
-            for k = 1, #gradInputSucc do
-               if gi then
-                  gi:add(gradInputSucc[k])
-               else
-                  gi = gradInputSucc[k]:clone()
-               end
-            end
-         end
-         nnm:updateGradInput(node.input, gi)
+         nnm:updateGradInput(node.input, self:computeGradInput(gradInputSucc))
      end
      -- We fill the gradInputSucc of our predecessors
print('******************************************************************')
print('** updateGradInput ***********************************************')
print('******************************************************************')
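+-- The first argument must mirror the nested structure of the DAG's input modules.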
-gradInput = g:updateGradInput({ input }, output)
+gradInput = g:updateGradInput({{input}}, output)
printTensorTable(gradInput)