From: Francois Fleuret <francois@fleuret.org>
Date: Wed, 11 Jan 2017 07:12:22 +0000 (+0100)
Subject: Update.
X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=452781856eafd237579e5c90b6e345354df91b42;p=dagnn.git

Update.
---

diff --git a/dagnn.lua b/dagnn.lua
index 52913ad..4841843 100755
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -11,6 +11,7 @@ function DAG:__init()
 end
 
 function DAG:addEdge(a, b)
+   self.sorted = nil
    local pred, succ = self.pred, self.succ
    if not pred[a] and not succ[a] then
       self:add(a)
@@ -25,6 +26,7 @@ function DAG:addEdge(a, b)
 end
 
 function DAG:setInput(i)
+   self.sorted = nil
    if torch.type(i) == 'table' then
       self.inputModules = i
       for _, m in ipairs(i) do
@@ -38,6 +40,7 @@ function DAG:setInput(i)
 end
 
 function DAG:setOutput(o)
+   self.sorted = nil
    if torch.type(o) == 'table' then
       self.outputModules = o
       for _, m in ipairs(o) do
@@ -50,7 +53,11 @@ function DAG:setOutput(o)
    end
 end
 
-function DAG:order()
+function DAG:sort()
+   if self.sorted then
+      return
+   end
+
    local distance = {}
 
    for _, a in pairs(self.inputModules) do
@@ -81,12 +88,16 @@ function DAG:order()
 end
 
 function DAG:print()
+   self:sort()
+
    for i, d in ipairs(self.sorted) do
       print('#' .. i .. ' -> ' .. torch.type(d))
    end
 end
 
 function DAG:updateOutput(input)
+   self:sort()
+
    if #self.inputModules == 1 then
       self.inputModules[1]:updateOutput(input)
    else
@@ -120,3 +131,9 @@ function DAG:updateOutput(input)
 
    return self.output
 end
+
+function DAG:updateGradInput(input, gradOutput)
+   self:sort()
+end
+
+return DAG
diff --git a/test-dagnn.lua b/test-dagnn.lua
index a0a81ab..262ea6f 100755
--- a/test-dagnn.lua
+++ b/test-dagnn.lua
@@ -21,10 +21,11 @@ f = nn.Linear(3, 2)
                      \---> f ---
 ]]--
 
-g = DAG:new()
+g = nn.DAG:new()
 
 g:setInput(a)
 g:setOutput({ e, f })
+
 g:addEdge(c, e)
 g:addEdge(a, b)
 g:addEdge(d, e)
@@ -32,9 +33,7 @@ g:addEdge(b, c)
 g:addEdge(b, d)
 g:addEdge(d, f)
 
-g:order()
-
-g:print(graph)
+g:print()
 
 input = torch.Tensor(3, 10):uniform()