Fixed the initialization in gradOutput in accGradParameters + cosmetics.
authorFrancois Fleuret <francois@fleuret.org>
Fri, 13 Jan 2017 15:20:49 +0000 (16:20 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Fri, 13 Jan 2017 15:20:49 +0000 (16:20 +0100)
dagnn.lua

index c17347d..0c1d153 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -178,6 +178,7 @@ function DAG:saveDot(filename)
       for i, nnma in pairs(node.pred) do
          local decoration = ''
          if #node.pred > 1 then
+            -- decoration = ' [headlabel=\"' .. i .. '\"]'
             decoration = ' [label=\"' .. i .. '\"]'
          end
          file:write(
@@ -239,12 +240,14 @@ function DAG:updateOutput(input)
 end
 
 function DAG:updateGradInput(input, gradOutput)
-   assert(self.sorted, 'there has been a DAG structure change before a DAG:updateGradInput')
+   assert(self.sorted, 'There has been a DAG structure change before a DAG:updateGradInput')
 
    self:nestedApply(
       function(nnm, go)
+         local node = self.node[nnm]
+         node.gradOutput = go
          -- nnm:updateGradInput(self.node[nnm].input, go)
-         self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', self.node[nnm].input, go)
+         self:rethrowErrors(nnm, node.index, 'updateGradInput', self.node[nnm].input, go)
       end,
       self.outputModules, gradOutput
    )
@@ -290,12 +293,22 @@ end
 function DAG:accGradParameters(input, gradOutput, scale)
    scale = scale or 1
 
-   assert(self.sorted, 'there has been a DAG structure change before a DAG:accGradParameters')
+   assert(self.sorted, 'There has been a DAG structure change before a DAG:accGradParameters')
+
+   self:nestedApply(
+      function(nnm, go) self.node[nnm].gradOutput = go end,
+      self.outputModules, gradOutput
+   )
+
+   self:nestedApply(
+      function(nnm, i) self.node[nnm].input = i end,
+      self.inputModules, input
+   )
 
    for k = 1, #self.modules do
       local nnm = self.modules[k]
       local node = self.node[nnm]
       -- nnm:accGradParameters(node.input, node.gradOutput, scale)
-      self:rethrowErrors(nnm, k, 'accGradParameters', node.input, self:computeGradOutput(node.gradInputSucc), scale)
+      self:rethrowErrors(nnm, k, 'accGradParameters', node.input, node.gradOutput, scale)
    end
 end