for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end
end
-function DAG:computeGradOutput(gradInputSucc)
- local gi
+function DAG:updateGradOutput(node)
+ local gradInputSucc = node.gradInputSucc
if #gradInputSucc == 1 then
- gi = gradInputSucc[1] -- we avoid a clone()
+ node.gradOutput = gradInputSucc[1]
elseif #gradInputSucc > 1 then
- for k = 1, #gradInputSucc do
- if gi then
- gi:add(gradInputSucc[k])
- else
- gi = gradInputSucc[k]:clone()
- end
+ if node.gradOutput then
+ node.gradOutput:resize(gradInputSucc[1]):copy(gradInputSucc[1])
+ else
+ node.gradOutput = gradInputSucc[1]:clone()
+ end
+ for k = 2, #gradInputSucc do
+ node.gradOutput:add(gradInputSucc[k])
end
end
- return gi
end
----------------------------------------------------------------------
for i, nnma in pairs(node.pred) do
local decoration = ''
if #node.pred > 1 then
+ -- decoration = ' [headlabel=\"' .. i .. '\"]'
decoration = ' [label=\"' .. i .. '\"]'
end
file:write(
self:nestedApply(
function(nnm, i)
self.node[nnm].input = i
- -- nnm:updateOutput(i)
self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
end,
self.inputModules,
end
end
node.input = i
- -- nnm:updateOutput(i)
self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
end
end
end
function DAG:updateGradInput(input, gradOutput)
- assert(self.sorted, 'there has been a DAG structure change before a DAG:updateGradInput')
+ assert(self.sorted, 'There has been a DAG structure change before a DAG:updateGradInput')
self:nestedApply(
function(nnm, go)
- -- nnm:updateGradInput(self.node[nnm].input, go)
- self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', self.node[nnm].input, go)
+ local node = self.node[nnm]
+ node.gradOutput = go
+ self:rethrowErrors(nnm, node.index, 'updateGradInput', self.node[nnm].input, go)
end,
self.outputModules, gradOutput
)
for k = #self.sorted, 1, -1 do
local nnm = self.sorted[k]
local node = self.node[nnm]
- local pred, gradInputSucc = node.pred, node.gradInputSucc
+ local pred = node.pred
- if #gradInputSucc > 0 then
- node.gradOutput = self:computeGradOutput(gradInputSucc)
- -- nnm:updateGradInput(node.input, node.gradOutput)
+ if #node.gradInputSucc > 0 then
+ self:updateGradOutput(node)
self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', node.input, node.gradOutput)
end
function DAG:accGradParameters(input, gradOutput, scale)
scale = scale or 1
- assert(self.sorted, 'there has been a DAG structure change before a DAG:accGradParameters')
+ assert(self.sorted, 'There has been a DAG structure change before a DAG:accGradParameters')
+
+ self:nestedApply(
+ function(nnm, go) self.node[nnm].gradOutput = go end,
+ self.outputModules, gradOutput
+ )
+
+ self:nestedApply(
+ function(nnm, i) self.node[nnm].input = i end,
+ self.inputModules, input
+ )
for k = 1, #self.modules do
local nnm = self.modules[k]
local node = self.node[nnm]
- -- nnm:accGradParameters(node.input, node.gradOutput, scale)
- self:rethrowErrors(nnm, k, 'accGradParameters', node.input, self:computeGradOutput(node.gradInputSucc), scale)
+ self:rethrowErrors(nnm, k, 'accGradParameters', node.input, node.gradOutput, scale)
end
end