Update.
authorFrancois Fleuret <francois@fleuret.org>
Tue, 10 Jan 2017 21:58:34 +0000 (22:58 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Tue, 10 Jan 2017 21:58:34 +0000 (22:58 +0100)
graphnn.lua

index 4981de4..1003500 100755 (executable)
@@ -118,7 +118,7 @@ function Graph:updateOutput(input)
       self.output = self.outputModules[1].output
    else
       self.output = { }
-      for i, d in ipairs(self.inputModules) do
+      for i, d in ipairs(self.outputModules) do
          self.output[i] = d.output
       end
    end
@@ -133,30 +133,34 @@ b = nn.ReLU()
 c = nn.Linear(10, 3)
 d = nn.Linear(10, 3)
 e = nn.CMulTable()
+f = nn.Linear(3, 2)
 
 --[[
 
    a -----> b ---> c ---- e ---
-   \           /
-   \--> d ---/
+             \           /
+              \--> d ---/
 
 ]]--
 
 g = Graph:new()
 
 g:setInput(a)
-g:setOutput(e)
+g:setOutput({ e, f })
 g:addEdge(c, e)
 g:addEdge(a, b)
 g:addEdge(d, e)
 g:addEdge(b, c)
 g:addEdge(b, d)
+g:addEdge(d, f)
 
 g:order()
+
 g:print(graph)
 
 input = torch.Tensor(3, 10):uniform()
 
 output = g:updateOutput(input)
 
-print(output)
+print(output[1])
+print(output[2])