X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=dagnn.lua;h=1b467e720542469d45228b1dbc8a8fd0b021f6ad;hb=e5030cca047eed4b8c5db172fc52e893b1b1d843;hp=1ec9b4ea86f40600f5ff3ac70ed7bfa926841dff;hpb=c8a895aa5f221f0de11733e8d05373e89ae9476e;p=dagnn.git diff --git a/dagnn.lua b/dagnn.lua index 1ec9b4e..1b467e7 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -1,21 +1,17 @@ -#!/usr/bin/env luajit require 'torch' require 'nn' -require 'image' -require 'optim' ----------------------------------------------------------------------- +local DAG, parent = torch.class('nn.DAG', 'nn.Container') -local Graph, parent = torch.class('nn.Graph', 'nn.Container') - -function Graph:__init() +function DAG:__init() parent.__init(self) self.pred = {} self.succ = {} end -function Graph:addEdge(a, b) +function DAG:addEdge(a, b) + self.sorted = nil local pred, succ = self.pred, self.succ if not pred[a] and not succ[a] then self:add(a) @@ -29,39 +25,63 @@ function Graph:addEdge(a, b) succ[a][#succ[a] + 1] = b end -function Graph:setInput(i) - if torch.type(i) == 'table' then - self.inputModules = i - for _, m in ipairs(i) do - if not self.pred[m] and not self.succ[m] then - self:add(m) - end +-- Apply f on t recursively; use the corresponding a1 and a2 elements +-- (i.e. same keys) as second and third parameters to f when +-- available; return the results from f, organized in a similarly +-- nested table. +function DAG:applyOnModules(f, t, a1, a2) + if torch.type(t) == 'table' then + local result = {} + for k, s in pairs(t) do + result[k] = self:applyOnModules(f, s, a1 and a1[k], a2 and a2[k]) end + return result else - self:setInput({ i }) + return f(t, a1, a2) end end -function Graph:setOutput(o) - if torch.type(o) == 'table' then - self.outputModules = o - for _, m in ipairs(o) do - if not self.pred[m] and not self.succ[m] then - self:add(m) +function DAG:setInput(i) + self.sorted = nil + self.inputModules = i + self:applyOnModules( + function(m) + if not self.succ[m] or #self.succ[m] == 0 then + error('Input modules must have outgoing edges.') end - end - else - self:setOutput({ o }) - end + if self.pred[m] and #self.pred[m] > 0 then + error('Input modules cannog have incoming edges.') + end + end, + self.inputModules + ) end -function Graph:order() - local distance = {} +function DAG:setOutput(o) + self.sorted = nil + self.outputModules = o + self:applyOnModules( + function(m) + if not self.pred[m] or #self.pred[m] == 0 then + error('Output module must have incoming edges.') + end + if self.succ[m] and #self.succ[m] > 0 then + error('Output module cannot have outgoing edges.') + end + end, + self.outputModules + ) +end - for _, a in pairs(self.inputModules) do - distance[a] = 1 +function DAG:sort() + if self.sorted then + return end + local distance = {} + + self:applyOnModules(function(m) distance[m] = 1 end, self.inputModules) + local nc repeat @@ -85,20 +105,18 @@ function Graph:order() for i, a in ipairs(self.sorted) do self.sorted[i] = a[2] end end -function Graph:print() +function DAG:print() + self:sort() + for i, d in ipairs(self.sorted) do print('#' .. i .. ' -> ' .. torch.type(d)) end end -function Graph:updateOutput(input) - if #self.inputModules == 1 then - self.inputModules[1]:updateOutput(input) - else - for i, d in ipairs(self.inputModules) do - d:updateOutput(input[i]) - end - end +function DAG:updateOutput(input) + self:sort() + + self:applyOnModules(function(m, i) m:updateOutput(i) end, self.inputModules, input) for _, d in ipairs(self.sorted) do if self.pred[d] then @@ -114,54 +132,41 @@ function Graph:updateOutput(input) end end - if #self.outputModules == 1 then - self.output = self.outputModules[1].output - else - self.output = { } - for i, d in ipairs(self.outputModules) do - self.output[i] = d.output - end - end + self.output = self:applyOnModules(function(m) return m.output end, self.outputModules) return self.output end ----------------------------------------------------------------------- - -a = nn.Linear(10, 10) -b = nn.ReLU() -c = nn.Linear(10, 3) -d = nn.Linear(10, 3) -e = nn.CMulTable() -f = nn.Linear(3, 2) - ---[[ - - a -----> b ---> c ----> e --- - \ / - \--> d ---/ - \ - \---> f --- -]]-- - -g = Graph:new() - -g:setInput(a) -g:setOutput({ e, f }) -g:addEdge(c, e) -g:addEdge(a, b) -g:addEdge(d, e) -g:addEdge(b, c) -g:addEdge(b, d) -g:addEdge(d, f) - -g:order() - -g:print(graph) +function DAG:updateGradInput(input, gradOutput) + self:sort() + + self:applyOnModules( + function(m, i, go) m:updateGradInput(i, go) end, + self.outputModules, input, gradOutput + ) + + for k = self.sorted, 1, -1 do + local m = sorted[k] + if self.succ[d] then + if #self.succ[d] == 1 then + d:updateGradInput(self.succ[d][1].gradInput) + elseif #self.succ[d] > 1 then + local sum + for k = 1, #self.succ[d] do + if sum then + sum:add(self.succ[d][k].gradInput) + else + sum = self.succ[d][k].gradInput:clone() + end + end + d:updateGradInput(sum) + end + end + end -input = torch.Tensor(3, 10):uniform() + self.gradInput = self:applyOnModules(function(m) return m.gradInput end, self.inputModules) -output = g:updateOutput(input) + return self.gradInput +end -print(output[1]) -print(output[2]) +return DAG