From: Francois Fleuret Date: Tue, 10 Jan 2017 21:55:31 +0000 (+0100) Subject: Graph:updateOutput seems to work (!) X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=b03b71f6af6f9747fc56dfffcb0e296a87725088;p=dagnn.git Graph:updateOutput seems to work (!) --- diff --git a/graphnn.lua b/graphnn.lua index a3ee1c1..4981de4 100755 --- a/graphnn.lua +++ b/graphnn.lua @@ -31,7 +31,7 @@ end function Graph:setInput(i) if torch.type(i) == 'table' then - self.input = i + self.inputModules = i for _, m in ipairs(i) do if not self.pred[m] and not self.succ[m] then self:add(m) @@ -44,7 +44,7 @@ end function Graph:setOutput(o) if torch.type(o) == 'table' then - self.output = o + self.outputModules = o for _, m in ipairs(o) do if not self.pred[m] and not self.succ[m] then self:add(m) @@ -58,7 +58,7 @@ end function Graph:order() local distance = {} - for _, a in pairs(self.input) do + for _, a in pairs(self.inputModules) do distance[a] = 1 end @@ -92,7 +92,38 @@ function Graph:print() end function Graph:updateOutput(input) - return self.output.output + if #self.inputModules == 1 then + self.inputModules[1]:updateOutput(input) + else + for i, d in ipairs(self.inputModules) do + d:updateOutput(input[i]) + end + end + + for _, d in ipairs(self.sorted) do + if self.pred[d] then + if #self.pred[d] == 1 then + d:updateOutput(self.pred[d][1].output) + elseif #self.pred[d] > 1 then + local c = {} + for k = 1, #self.pred[d] do + c[k] = self.pred[d][k].output + end + d:updateOutput(c) + end + end + end + + if #self.outputModules == 1 then + self.output = self.outputModules[1].output + else + self.output = { } + for i, d in ipairs(self.inputModules) do + self.output[i] = d.output + end + end + + return self.output end ---------------------------------------------------------------------- @@ -106,8 +137,8 @@ e = nn.CMulTable() --[[ a -----> b ---> c ---- e --- - \ / - \--> d ---/ + \ / + \--> d ---/ ]]-- @@ -123,3 +154,9 @@ g:addEdge(b, d) g:order() g:print(graph) + +input = torch.Tensor(3, 10):uniform() + +output = g:updateOutput(input) + +print(output)