for i, nnma in pairs(node.pred) do
local decoration = ''
if #node.pred > 1 then
+ -- decoration = ' [headlabel=\"' .. i .. '\"]'
decoration = ' [label=\"' .. i .. '\"]'
end
file:write(
end
function DAG:updateGradInput(input, gradOutput)
- assert(self.sorted, 'there has been a DAG structure change before a DAG:updateGradInput')
+ assert(self.sorted, 'There has been a DAG structure change before a DAG:updateGradInput')
self:nestedApply(
function(nnm, go)
+ local node = self.node[nnm]
+ node.gradOutput = go
-- nnm:updateGradInput(self.node[nnm].input, go)
- self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', self.node[nnm].input, go)
+ self:rethrowErrors(nnm, node.index, 'updateGradInput', self.node[nnm].input, go)
end,
self.outputModules, gradOutput
)
function DAG:accGradParameters(input, gradOutput, scale)
scale = scale or 1
- assert(self.sorted, 'there has been a DAG structure change before a DAG:accGradParameters')
+ assert(self.sorted, 'There has been a DAG structure change before a DAG:accGradParameters')
+
+ self:nestedApply(
+ function(nnm, go) self.node[nnm].gradOutput = go end,
+ self.outputModules, gradOutput
+ )
+
+ self:nestedApply(
+ function(nnm, i) self.node[nnm].input = i end,
+ self.inputModules, input
+ )
for k = 1, #self.modules do
local nnm = self.modules[k]
local node = self.node[nnm]
-- nnm:accGradParameters(node.input, node.gradOutput, scale)
- self:rethrowErrors(nnm, k, 'accGradParameters', node.input, self:computeGradOutput(node.gradInputSucc), scale)
+ self:rethrowErrors(nnm, k, 'accGradParameters', node.input, node.gradOutput, scale)
end
end