X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=dagnn.lua;h=0c1d15303f42ce6e8ad439b787bffe29664511c3;hb=e50d9b4373f39161df34afb1033c89910963fa47;hp=14cd5821c2064ac74b50dde101449d10d8f2274e;hpb=fe54a7c5c8425ee9783d82e16a42924e23add457;p=dagnn.git diff --git a/dagnn.lua b/dagnn.lua index 14cd582..0c1d153 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -29,7 +29,7 @@ function DAG:__init() self.node = { } end --- Apply f on t recursively; use the corresponding element from args +-- Apply f on t recursively; use the corresponding elements from args -- (i.e. same keys) as second parameter to f when available; return -- the results from f, organized in a similarly nested table. function DAG:nestedApply(f, t, args) @@ -167,20 +167,26 @@ function DAG:saveDot(filename) file:write('\n') - for nnma, node in pairs(self.node) do + for nnmb, node in pairs(self.node) do file:write( ' ' .. node.index - .. ' [shape=box,label=\"' .. torch.type(nnma) .. '\"]' + .. ' [shape=box,label=\"' .. torch.type(nnmb) .. '\"]' .. '\n' ) - for _, nnmb in pairs(node.succ) do + for i, nnma in pairs(node.pred) do + local decoration = '' + if #node.pred > 1 then + -- decoration = ' [headlabel=\"' .. i .. '\"]' + decoration = ' [label=\"' .. i .. '\"]' + end file:write( ' ' - .. node.index + .. self.node[nnma].index .. ' -> ' .. self.node[nnmb].index + .. decoration .. '\n' ) end @@ -234,12 +240,14 @@ function DAG:updateOutput(input) end function DAG:updateGradInput(input, gradOutput) - assert(self.sorted, 'there has been a DAG structure change before a DAG:updateGradInput') + assert(self.sorted, 'There has been a DAG structure change before a DAG:updateGradInput') self:nestedApply( function(nnm, go) + local node = self.node[nnm] + node.gradOutput = go -- nnm:updateGradInput(self.node[nnm].input, go) - self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', self.node[nnm].input, go) + self:rethrowErrors(nnm, node.index, 'updateGradInput', self.node[nnm].input, go) end, self.outputModules, gradOutput ) @@ -285,12 +293,22 @@ end function DAG:accGradParameters(input, gradOutput, scale) scale = scale or 1 - assert(self.sorted, 'there has been a DAG structure change before a DAG:accGradParameters') + assert(self.sorted, 'There has been a DAG structure change before a DAG:accGradParameters') + + self:nestedApply( + function(nnm, go) self.node[nnm].gradOutput = go end, + self.outputModules, gradOutput + ) + + self:nestedApply( + function(nnm, i) self.node[nnm].input = i end, + self.inputModules, input + ) for k = 1, #self.modules do local nnm = self.modules[k] local node = self.node[nnm] -- nnm:accGradParameters(node.input, node.gradOutput, scale) - self:rethrowErrors(nnm, k, 'accGradParameters', node.input, self:computeGradOutput(node.gradInputSucc), scale) + self:rethrowErrors(nnm, k, 'accGradParameters', node.input, node.gradOutput, scale) end end