function Graph:setInput(i)
if torch.type(i) == 'table' then
- self.input = i
+ self.inputModules = i
for _, m in ipairs(i) do
if not self.pred[m] and not self.succ[m] then
self:add(m)
function Graph:setOutput(o)
if torch.type(o) == 'table' then
- self.output = o
+ self.outputModules = o
for _, m in ipairs(o) do
if not self.pred[m] and not self.succ[m] then
self:add(m)
function Graph:order()
local distance = {}
- for _, a in pairs(self.input) do
+ for _, a in pairs(self.inputModules) do
distance[a] = 1
end
end
function Graph:updateOutput(input)
- return self.output.output
+ if #self.inputModules == 1 then
+ self.inputModules[1]:updateOutput(input)
+ else
+ for i, d in ipairs(self.inputModules) do
+ d:updateOutput(input[i])
+ end
+ end
+
+ for _, d in ipairs(self.sorted) do
+ if self.pred[d] then
+ if #self.pred[d] == 1 then
+ d:updateOutput(self.pred[d][1].output)
+ elseif #self.pred[d] > 1 then
+ local c = {}
+ for k = 1, #self.pred[d] do
+ c[k] = self.pred[d][k].output
+ end
+ d:updateOutput(c)
+ end
+ end
+ end
+
+ if #self.outputModules == 1 then
+ self.output = self.outputModules[1].output
+ else
+ self.output = { }
+ for i, d in ipairs(self.inputModules) do
+ self.output[i] = d.output
+ end
+ end
+
+ return self.output
end
----------------------------------------------------------------------
--[[
a -----> b ---> c ---- e ---
- \ /
- \--> d ---/
+ \ /
+ \--> d ---/
]]--
g:order()
g:print(graph)
+
+input = torch.Tensor(3, 10):uniform()
+
+output = g:updateOutput(input)
+
+print(output)