X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=graphnn.lua;fp=graphnn.lua;h=0000000000000000000000000000000000000000;hb=c8a895aa5f221f0de11733e8d05373e89ae9476e;hp=1ec9b4ea86f40600f5ff3ac70ed7bfa926841dff;hpb=7a797ca230bfe70bba3e46f41d56c2b1cff3580f;p=dagnn.git diff --git a/graphnn.lua b/graphnn.lua deleted file mode 100755 index 1ec9b4e..0000000 --- a/graphnn.lua +++ /dev/null @@ -1,167 +0,0 @@ -#!/usr/bin/env luajit - -require 'torch' -require 'nn' -require 'image' -require 'optim' - ----------------------------------------------------------------------- - -local Graph, parent = torch.class('nn.Graph', 'nn.Container') - -function Graph:__init() - parent.__init(self) - self.pred = {} - self.succ = {} -end - -function Graph:addEdge(a, b) - local pred, succ = self.pred, self.succ - if not pred[a] and not succ[a] then - self:add(a) - end - if not pred[b] and not succ[b] then - self:add(b) - end - pred[b] = pred[b] or {} - pred[b][#pred[b] + 1] = a - succ[a] = succ[a] or {} - succ[a][#succ[a] + 1] = b -end - -function Graph:setInput(i) - if torch.type(i) == 'table' then - self.inputModules = i - for _, m in ipairs(i) do - if not self.pred[m] and not self.succ[m] then - self:add(m) - end - end - else - self:setInput({ i }) - end -end - -function Graph:setOutput(o) - if torch.type(o) == 'table' then - self.outputModules = o - for _, m in ipairs(o) do - if not self.pred[m] and not self.succ[m] then - self:add(m) - end - end - else - self:setOutput({ o }) - end -end - -function Graph:order() - local distance = {} - - for _, a in pairs(self.inputModules) do - distance[a] = 1 - end - - local nc - - repeat - nc = 0 - for i, isucc in pairs(self.succ) do - for _, j in pairs(isucc) do - if distance[i] and (not distance[j] or distance[j] < distance[i] + 1) then - distance[j] = distance[i] + 1 - nc = nc + 1 - end - end - end - until nc == 0 - - self.sorted = { } - for i, d in pairs(distance) do - table.insert(self.sorted, { d, i }) - end - - table.sort(self.sorted, function(a, b) return a[1] < b[1] end) - for i, a in ipairs(self.sorted) do self.sorted[i] = a[2] end -end - -function Graph:print() - for i, d in ipairs(self.sorted) do - print('#' .. i .. ' -> ' .. torch.type(d)) - end -end - -function Graph:updateOutput(input) - if #self.inputModules == 1 then - self.inputModules[1]:updateOutput(input) - else - for i, d in ipairs(self.inputModules) do - d:updateOutput(input[i]) - end - end - - for _, d in ipairs(self.sorted) do - if self.pred[d] then - if #self.pred[d] == 1 then - d:updateOutput(self.pred[d][1].output) - elseif #self.pred[d] > 1 then - local c = {} - for k = 1, #self.pred[d] do - c[k] = self.pred[d][k].output - end - d:updateOutput(c) - end - end - end - - if #self.outputModules == 1 then - self.output = self.outputModules[1].output - else - self.output = { } - for i, d in ipairs(self.outputModules) do - self.output[i] = d.output - end - end - - return self.output -end - ----------------------------------------------------------------------- - -a = nn.Linear(10, 10) -b = nn.ReLU() -c = nn.Linear(10, 3) -d = nn.Linear(10, 3) -e = nn.CMulTable() -f = nn.Linear(3, 2) - ---[[ - - a -----> b ---> c ----> e --- - \ / - \--> d ---/ - \ - \---> f --- -]]-- - -g = Graph:new() - -g:setInput(a) -g:setOutput({ e, f }) -g:addEdge(c, e) -g:addEdge(a, b) -g:addEdge(d, e) -g:addEdge(b, c) -g:addEdge(b, d) -g:addEdge(d, f) - -g:order() - -g:print(graph) - -input = torch.Tensor(3, 10):uniform() - -output = g:updateOutput(input) - -print(output[1]) -print(output[2])