From: Francois Fleuret Date: Fri, 13 Jan 2017 15:20:49 +0000 (+0100) Subject: Fixed the initialization in gradOutput in accGradParameters + cosmetics. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=e50d9b4373f39161df34afb1033c89910963fa47;p=dagnn.git Fixed the initialization in gradOutput in accGradParameters + cosmetics. --- diff --git a/dagnn.lua b/dagnn.lua index c17347d..0c1d153 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -178,6 +178,7 @@ function DAG:saveDot(filename) for i, nnma in pairs(node.pred) do local decoration = '' if #node.pred > 1 then + -- decoration = ' [headlabel=\"' .. i .. '\"]' decoration = ' [label=\"' .. i .. '\"]' end file:write( @@ -239,12 +240,14 @@ function DAG:updateOutput(input) end function DAG:updateGradInput(input, gradOutput) - assert(self.sorted, 'there has been a DAG structure change before a DAG:updateGradInput') + assert(self.sorted, 'There has been a DAG structure change before a DAG:updateGradInput') self:nestedApply( function(nnm, go) + local node = self.node[nnm] + node.gradOutput = go -- nnm:updateGradInput(self.node[nnm].input, go) - self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', self.node[nnm].input, go) + self:rethrowErrors(nnm, node.index, 'updateGradInput', self.node[nnm].input, go) end, self.outputModules, gradOutput ) @@ -290,12 +293,22 @@ end function DAG:accGradParameters(input, gradOutput, scale) scale = scale or 1 - assert(self.sorted, 'there has been a DAG structure change before a DAG:accGradParameters') + assert(self.sorted, 'There has been a DAG structure change before a DAG:accGradParameters') + + self:nestedApply( + function(nnm, go) self.node[nnm].gradOutput = go end, + self.outputModules, gradOutput + ) + + self:nestedApply( + function(nnm, i) self.node[nnm].input = i end, + self.inputModules, input + ) for k = 1, #self.modules do local nnm = self.modules[k] local node = self.node[nnm] -- nnm:accGradParameters(node.input, node.gradOutput, scale) - self:rethrowErrors(nnm, k, 'accGradParameters', node.input, self:computeGradOutput(node.gradInputSucc), scale) + self:rethrowErrors(nnm, k, 'accGradParameters', node.input, node.gradOutput, scale) end end