From: Francois Fleuret Date: Wed, 11 Jan 2017 07:12:22 +0000 (+0100) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=452781856eafd237579e5c90b6e345354df91b42;p=dagnn.git Update. --- diff --git a/dagnn.lua b/dagnn.lua index 52913ad..4841843 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -11,6 +11,7 @@ function DAG:__init() end function DAG:addEdge(a, b) + self.sorted = nil local pred, succ = self.pred, self.succ if not pred[a] and not succ[a] then self:add(a) @@ -25,6 +26,7 @@ function DAG:addEdge(a, b) end function DAG:setInput(i) + self.sorted = nil if torch.type(i) == 'table' then self.inputModules = i for _, m in ipairs(i) do @@ -38,6 +40,7 @@ function DAG:setInput(i) end function DAG:setOutput(o) + self.sorted = nil if torch.type(o) == 'table' then self.outputModules = o for _, m in ipairs(o) do @@ -50,7 +53,11 @@ function DAG:setOutput(o) end end -function DAG:order() +function DAG:sort() + if self.sorted then + return + end + local distance = {} for _, a in pairs(self.inputModules) do @@ -81,12 +88,16 @@ function DAG:order() end function DAG:print() + self:sort() + for i, d in ipairs(self.sorted) do print('#' .. i .. ' -> ' .. torch.type(d)) end end function DAG:updateOutput(input) + self:sort() + if #self.inputModules == 1 then self.inputModules[1]:updateOutput(input) else @@ -120,3 +131,9 @@ function DAG:updateOutput(input) return self.output end + +function DAG:updateGradInput(input, gradOutput) + self:sort() +end + +return DAG diff --git a/test-dagnn.lua b/test-dagnn.lua index a0a81ab..262ea6f 100755 --- a/test-dagnn.lua +++ b/test-dagnn.lua @@ -21,10 +21,11 @@ f = nn.Linear(3, 2) \---> f --- ]]-- -g = DAG:new() +g = nn.DAG:new() g:setInput(a) g:setOutput({ e, f }) + g:addEdge(c, e) g:addEdge(a, b) g:addEdge(d, e) @@ -32,9 +33,7 @@ g:addEdge(b, c) g:addEdge(b, d) g:addEdge(d, f) -g:order() - -g:print(graph) +g:print() input = torch.Tensor(3, 10):uniform()