Remove the clone() for node.gradOutput when possible.
[dagnn.git] / dagnn.lua
index 14cd582..de9d29b 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -29,7 +29,7 @@ function DAG:__init()
    self.node = { }
 end
 
--- Apply f on t recursively; use the corresponding element from args
+-- Apply f on t recursively; use the corresponding elements from args
 -- (i.e. same keys) as second parameter to f when available; return
 -- the results from f, organized in a similarly nested table.
 function DAG:nestedApply(f, t, args)
@@ -86,20 +86,20 @@ function DAG:putInOrder()
    for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end
 end
 
-function DAG:computeGradOutput(gradInputSucc)
-   local gi
+function DAG:updateGradOutput(node)
+   local gradInputSucc = node.gradInputSucc
    if #gradInputSucc == 1 then
-      gi = gradInputSucc[1] -- we avoid a clone()
+      node.gradOutput = gradInputSucc[1]
    elseif #gradInputSucc > 1 then
-      for k = 1, #gradInputSucc do
-         if gi then
-            gi:add(gradInputSucc[k])
-         else
-            gi = gradInputSucc[k]:clone()
-         end
+      if node.gradOutput then
+         node.gradOutput:resize(gradInputSucc[1]):copy(gradInputSucc[1])
+      else
+         node.gradOutput = gradInputSucc[1]:clone()
+      end
+      for k = 2, #gradInputSucc do
+         node.gradOutput:add(gradInputSucc[k])
       end
    end
-   return gi
 end
 
 ----------------------------------------------------------------------
@@ -167,20 +167,26 @@ function DAG:saveDot(filename)
 
    file:write('\n')
 
-   for nnma, node in pairs(self.node) do
+   for nnmb, node in pairs(self.node) do
       file:write(
          '  '
             .. node.index
-            .. ' [shape=box,label=\"' .. torch.type(nnma) .. '\"]'
+            .. ' [shape=box,label=\"' .. torch.type(nnmb) .. '\"]'
             .. '\n'
       )
 
-      for _, nnmb in pairs(node.succ) do
+      for i, nnma in pairs(node.pred) do
+         local decoration = ''
+         if #node.pred > 1 then
+            -- decoration = ' [headlabel=\"' .. i .. '\"]'
+            decoration = ' [label=\"' .. i .. '\"]'
+         end
          file:write(
             '  '
-               .. node.index
+               .. self.node[nnma].index
                .. ' -> '
                .. self.node[nnmb].index
+               .. decoration
                .. '\n'
          )
       end
@@ -200,7 +206,6 @@ function DAG:updateOutput(input)
    self:nestedApply(
       function(nnm, i)
          self.node[nnm].input = i
-         -- nnm:updateOutput(i)
          self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
       end,
       self.inputModules,
@@ -220,7 +225,6 @@ function DAG:updateOutput(input)
             end
          end
          node.input = i
-         -- nnm:updateOutput(i)
          self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
       end
    end
@@ -234,12 +238,13 @@ function DAG:updateOutput(input)
 end
 
 function DAG:updateGradInput(input, gradOutput)
-   assert(self.sorted, 'there has been a DAG structure change before a DAG:updateGradInput')
+   assert(self.sorted, 'There has been a DAG structure change before a DAG:updateGradInput')
 
    self:nestedApply(
       function(nnm, go)
-         -- nnm:updateGradInput(self.node[nnm].input, go)
-         self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', self.node[nnm].input, go)
+         local node = self.node[nnm]
+         node.gradOutput = go
+         self:rethrowErrors(nnm, node.index, 'updateGradInput', self.node[nnm].input, go)
       end,
       self.outputModules, gradOutput
    )
@@ -256,11 +261,10 @@ function DAG:updateGradInput(input, gradOutput)
    for k = #self.sorted, 1, -1 do
       local nnm = self.sorted[k]
       local node = self.node[nnm]
-      local pred, gradInputSucc = node.pred, node.gradInputSucc
+      local pred = node.pred
 
-      if #gradInputSucc > 0 then
-         node.gradOutput = self:computeGradOutput(gradInputSucc)
-         -- nnm:updateGradInput(node.input, node.gradOutput)
+      if #node.gradInputSucc > 0 then
+         self:updateGradOutput(node)
          self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', node.input, node.gradOutput)
       end
 
@@ -285,12 +289,21 @@ end
 function DAG:accGradParameters(input, gradOutput, scale)
    scale = scale or 1
 
-   assert(self.sorted, 'there has been a DAG structure change before a DAG:accGradParameters')
+   assert(self.sorted, 'There has been a DAG structure change before a DAG:accGradParameters')
+
+   self:nestedApply(
+      function(nnm, go) self.node[nnm].gradOutput = go end,
+      self.outputModules, gradOutput
+   )
+
+   self:nestedApply(
+      function(nnm, i) self.node[nnm].input = i end,
+      self.inputModules, input
+   )
 
    for k = 1, #self.modules do
       local nnm = self.modules[k]
       local node = self.node[nnm]
-      -- nnm:accGradParameters(node.input, node.gradOutput, scale)
-      self:rethrowErrors(nnm, k, 'accGradParameters', node.input, self:computeGradOutput(node.gradInputSucc), scale)
+      self:rethrowErrors(nnm, k, 'accGradParameters', node.input, node.gradOutput, scale)
    end
 end