- for k = self.sorted, 1, -1 do
- local m = sorted[k]
- if self.succ[d] then
- if #self.succ[d] == 1 then
- d:updateGradInput(self.succ[d][1].gradInput)
- elseif #self.succ[d] > 1 then
- local sum
- for k = 1, #self.succ[d] do
- if sum then
- sum:add(self.succ[d][k].gradInput)
- else
- sum = self.succ[d][k].gradInput:clone()
- end
- end
- d:updateGradInput(sum)
+ self:nestedApply(
+ function(nnm, i) self.node[nnm].input = i end,
+ self.inputModules, input
+ )
+
+ for _, node in pairs(self.node) do
+ node.gradInputSucc = {}
+ end
+
+ for k = #self.sorted, 1, -1 do
+ local nnm = self.sorted[k]
+ local node = self.node[nnm]
+ local pred, gradInputSucc = node.pred, node.gradInputSucc
+
+ if #gradInputSucc > 0 then
+ node.gradOutput = self:computeGradOutput(gradInputSucc)
+ -- nnm:updateGradInput(node.input, node.gradOutput)
+ self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', node.input, node.gradOutput)
+ end
+
+ -- We fill the gradInputSucc of our predecessors
+ if #pred == 1 then
+ table.insert(self.node[pred[1]].gradInputSucc, nnm.gradInput)
+ elseif #pred > 1 then
+ if not torch.type(nnm.gradInput) == 'table' then
+ error('Should have a table gradInput since it has multiple predecessors')
+ end
+ for n = 1, #pred do
+ table.insert(self.node[node.pred[n]].gradInputSucc, nnm.gradInput[n])