From: Francois Fleuret Date: Tue, 10 Jan 2017 21:35:54 +0000 (+0100) Subject: Initial commit. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=48eecc7b955278730154a150a2cd7d1fa3d8a88e;p=dagnn.git Initial commit. --- 48eecc7b955278730154a150a2cd7d1fa3d8a88e diff --git a/graphnn.lua b/graphnn.lua new file mode 100755 index 0000000..a3ee1c1 --- /dev/null +++ b/graphnn.lua @@ -0,0 +1,125 @@ +#!/usr/bin/env luajit + +require 'torch' +require 'nn' +require 'image' +require 'optim' + +---------------------------------------------------------------------- + +local Graph, parent = torch.class('nn.Graph', 'nn.Container') + +function Graph:__init() + parent.__init(self) + self.pred = {} + self.succ = {} +end + +function Graph:addEdge(a, b) + local pred, succ = self.pred, self.succ + if not pred[a] and not succ[a] then + self:add(a) + end + if not pred[b] and not succ[b] then + self:add(b) + end + pred[b] = pred[b] or {} + pred[b][#pred[b] + 1] = a + succ[a] = succ[a] or {} + succ[a][#succ[a] + 1] = b +end + +function Graph:setInput(i) + if torch.type(i) == 'table' then + self.input = i + for _, m in ipairs(i) do + if not self.pred[m] and not self.succ[m] then + self:add(m) + end + end + else + self:setInput({ i }) + end +end + +function Graph:setOutput(o) + if torch.type(o) == 'table' then + self.output = o + for _, m in ipairs(o) do + if not self.pred[m] and not self.succ[m] then + self:add(m) + end + end + else + self:setOutput({ o }) + end +end + +function Graph:order() + local distance = {} + + for _, a in pairs(self.input) do + distance[a] = 1 + end + + local nc + + repeat + nc = 0 + for i, isucc in pairs(self.succ) do + for _, j in pairs(isucc) do + if distance[i] and (not distance[j] or distance[j] < distance[i] + 1) then + distance[j] = distance[i] + 1 + nc = nc + 1 + end + end + end + until nc == 0 + + self.sorted = { } + for i, d in pairs(distance) do + table.insert(self.sorted, { d, i }) + end + + table.sort(self.sorted, function(a, b) return a[1] < b[1] end) + for i, a in ipairs(self.sorted) do self.sorted[i] = a[2] end +end + +function Graph:print() + for i, d in ipairs(self.sorted) do + print('#' .. i .. ' -> ' .. torch.type(d)) + end +end + +function Graph:updateOutput(input) + return self.output.output +end + +---------------------------------------------------------------------- + +a = nn.Linear(10, 10) +b = nn.ReLU() +c = nn.Linear(10, 3) +d = nn.Linear(10, 3) +e = nn.CMulTable() + +--[[ + + a -----> b ---> c ---- e --- + \ / + \--> d ---/ + +]]-- + +g = Graph:new() + +g:setInput(a) +g:setOutput(e) +g:addEdge(c, e) +g:addEdge(a, b) +g:addEdge(d, e) +g:addEdge(b, c) +g:addEdge(b, d) + +g:order() +g:print(graph)