Graph:updateOutput seems to work (!)
authorFrancois Fleuret <francois@fleuret.org>
Tue, 10 Jan 2017 21:55:31 +0000 (22:55 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Tue, 10 Jan 2017 21:55:31 +0000 (22:55 +0100)
graphnn.lua

index a3ee1c1..4981de4 100755 (executable)
@@ -31,7 +31,7 @@ end
 
 function Graph:setInput(i)
    if torch.type(i) == 'table' then
-      self.input = i
+      self.inputModules = i
       for _, m in ipairs(i) do
          if not self.pred[m] and not self.succ[m] then
             self:add(m)
@@ -44,7 +44,7 @@ end
 
 function Graph:setOutput(o)
    if torch.type(o) == 'table' then
-      self.output = o
+      self.outputModules = o
       for _, m in ipairs(o) do
          if not self.pred[m] and not self.succ[m] then
             self:add(m)
@@ -58,7 +58,7 @@ end
 function Graph:order()
    local distance = {}
 
-   for _, a in pairs(self.input) do
+   for _, a in pairs(self.inputModules) do
       distance[a] = 1
    end
 
@@ -92,7 +92,38 @@ function Graph:print()
 end
 
 function Graph:updateOutput(input)
-   return self.output.output
+   if #self.inputModules == 1 then
+      self.inputModules[1]:updateOutput(input)
+   else
+      for i, d in ipairs(self.inputModules) do
+         d:updateOutput(input[i])
+      end
+   end
+
+   for _, d in ipairs(self.sorted) do
+      if self.pred[d] then
+         if #self.pred[d] == 1 then
+            d:updateOutput(self.pred[d][1].output)
+         elseif #self.pred[d] > 1 then
+            local c = {}
+            for k = 1, #self.pred[d] do
+               c[k] = self.pred[d][k].output
+            end
+            d:updateOutput(c)
+         end
+      end
+   end
+
+   if #self.outputModules == 1 then
+      self.output = self.outputModules[1].output
+   else
+      self.output = { }
+      for i, d in ipairs(self.inputModules) do
+         self.output[i] = d.output
+      end
+   end
+
+   return self.output
 end
 
 ----------------------------------------------------------------------
@@ -106,8 +137,8 @@ e = nn.CMulTable()
 --[[
 
    a -----> b ---> c ---- e ---
-             \           /
-              \--> d ---/
+   \           /
+   \--> d ---/
 
 ]]--
 
@@ -123,3 +154,9 @@ g:addEdge(b, d)
 
 g:order()
 g:print(graph)
+
+input = torch.Tensor(3, 10):uniform()
+
+output = g:updateOutput(input)
+
+print(output)