From e50d9b4373f39161df34afb1033c89910963fa47 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Fri, 13 Jan 2017 16:20:49 +0100 Subject: [PATCH] Fixed the initialization in gradOutput in accGradParameters + cosmetics. --- dagnn.lua | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/dagnn.lua b/dagnn.lua index c17347d..0c1d153 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -178,6 +178,7 @@ function DAG:saveDot(filename) for i, nnma in pairs(node.pred) do local decoration = '' if #node.pred > 1 then + -- decoration = ' [headlabel=\"' .. i .. '\"]' decoration = ' [label=\"' .. i .. '\"]' end file:write( @@ -239,12 +240,14 @@ function DAG:updateOutput(input) end function DAG:updateGradInput(input, gradOutput) - assert(self.sorted, 'there has been a DAG structure change before a DAG:updateGradInput') + assert(self.sorted, 'There has been a DAG structure change before a DAG:updateGradInput') self:nestedApply( function(nnm, go) + local node = self.node[nnm] + node.gradOutput = go -- nnm:updateGradInput(self.node[nnm].input, go) - self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', self.node[nnm].input, go) + self:rethrowErrors(nnm, node.index, 'updateGradInput', self.node[nnm].input, go) end, self.outputModules, gradOutput ) @@ -290,12 +293,22 @@ end function DAG:accGradParameters(input, gradOutput, scale) scale = scale or 1 - assert(self.sorted, 'there has been a DAG structure change before a DAG:accGradParameters') + assert(self.sorted, 'There has been a DAG structure change before a DAG:accGradParameters') + + self:nestedApply( + function(nnm, go) self.node[nnm].gradOutput = go end, + self.outputModules, gradOutput + ) + + self:nestedApply( + function(nnm, i) self.node[nnm].input = i end, + self.inputModules, input + ) for k = 1, #self.modules do local nnm = self.modules[k] local node = self.node[nnm] -- nnm:accGradParameters(node.input, node.gradOutput, scale) - self:rethrowErrors(nnm, k, 'accGradParameters', node.input, self:computeGradOutput(node.gradInputSucc), scale) + self:rethrowErrors(nnm, k, 'accGradParameters', node.input, node.gradOutput, scale) end end -- 2.20.1