X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=graphnn.lua;h=1003500e6a876eba02dafbe6ffe92e5db66776d4;hb=a29cfbfae14bba7d8a8c8b3e596fa1fbe79aa637;hp=a3ee1c152c79860eb70904d81a224431a57c9c03;hpb=48eecc7b955278730154a150a2cd7d1fa3d8a88e;p=dagnn.git diff --git a/graphnn.lua b/graphnn.lua index a3ee1c1..1003500 100755 --- a/graphnn.lua +++ b/graphnn.lua @@ -31,7 +31,7 @@ end function Graph:setInput(i) if torch.type(i) == 'table' then - self.input = i + self.inputModules = i for _, m in ipairs(i) do if not self.pred[m] and not self.succ[m] then self:add(m) @@ -44,7 +44,7 @@ end function Graph:setOutput(o) if torch.type(o) == 'table' then - self.output = o + self.outputModules = o for _, m in ipairs(o) do if not self.pred[m] and not self.succ[m] then self:add(m) @@ -58,7 +58,7 @@ end function Graph:order() local distance = {} - for _, a in pairs(self.input) do + for _, a in pairs(self.inputModules) do distance[a] = 1 end @@ -92,7 +92,38 @@ function Graph:print() end function Graph:updateOutput(input) - return self.output.output + if #self.inputModules == 1 then + self.inputModules[1]:updateOutput(input) + else + for i, d in ipairs(self.inputModules) do + d:updateOutput(input[i]) + end + end + + for _, d in ipairs(self.sorted) do + if self.pred[d] then + if #self.pred[d] == 1 then + d:updateOutput(self.pred[d][1].output) + elseif #self.pred[d] > 1 then + local c = {} + for k = 1, #self.pred[d] do + c[k] = self.pred[d][k].output + end + d:updateOutput(c) + end + end + end + + if #self.outputModules == 1 then + self.output = self.outputModules[1].output + else + self.output = { } + for i, d in ipairs(self.outputModules) do + self.output[i] = d.output + end + end + + return self.output end ---------------------------------------------------------------------- @@ -102,6 +133,7 @@ b = nn.ReLU() c = nn.Linear(10, 3) d = nn.Linear(10, 3) e = nn.CMulTable() +f = nn.Linear(3, 2) --[[ @@ -114,12 +146,21 @@ e = nn.CMulTable() g = Graph:new() g:setInput(a) -g:setOutput(e) +g:setOutput({ e, f }) g:addEdge(c, e) g:addEdge(a, b) g:addEdge(d, e) g:addEdge(b, c) g:addEdge(b, d) +g:addEdge(d, f) g:order() + g:print(graph) + +input = torch.Tensor(3, 10):uniform() + +output = g:updateOutput(input) + +print(output[1]) +print(output[2])