for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end
end
-function DAG:computeGradOutput(gradInputSucc)
- local gi
+function DAG:updateGradOutput(node)
+ local gradInputSucc = node.gradInputSucc
if #gradInputSucc == 1 then
- gi = gradInputSucc[1] -- we avoid a clone()
+ node.gradOutput = gradInputSucc[1]
elseif #gradInputSucc > 1 then
- for k = 1, #gradInputSucc do
- if gi then
- gi:add(gradInputSucc[k])
- else
- gi = gradInputSucc[k]:clone()
- end
+ if node.gradOutput then
+ node.gradOutput:resize(gradInputSucc[1]):copy(gradInputSucc[1])
+ else
+ node.gradOutput = gradInputSucc[1]:clone()
+ end
+ for k = 2, #gradInputSucc do
+ node.gradOutput:add(gradInputSucc[k])
end
end
- return gi
end
----------------------------------------------------------------------
self:nestedApply(
function(nnm, i)
self.node[nnm].input = i
- -- nnm:updateOutput(i)
self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
end,
self.inputModules,
end
end
node.input = i
- -- nnm:updateOutput(i)
self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
end
end
function(nnm, go)
local node = self.node[nnm]
node.gradOutput = go
- -- nnm:updateGradInput(self.node[nnm].input, go)
self:rethrowErrors(nnm, node.index, 'updateGradInput', self.node[nnm].input, go)
end,
self.outputModules, gradOutput
for k = #self.sorted, 1, -1 do
local nnm = self.sorted[k]
local node = self.node[nnm]
- local pred, gradInputSucc = node.pred, node.gradInputSucc
+ local pred = node.pred
- if #gradInputSucc > 0 then
- node.gradOutput = self:computeGradOutput(gradInputSucc)
- -- nnm:updateGradInput(node.input, node.gradOutput)
+ if #node.gradInputSucc > 0 then
+ self:updateGradOutput(node)
self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', node.input, node.gradOutput)
end
for k = 1, #self.modules do
local nnm = self.modules[k]
local node = self.node[nnm]
- -- nnm:accGradParameters(node.input, node.gradOutput, scale)
self:rethrowErrors(nnm, k, 'accGradParameters', node.input, node.gradOutput, scale)
end
end