5 local DAG, parent = torch.class('nn.DAG', 'nn.Container')
13 function DAG:addEdge(a, b)
15 local pred, succ = self.pred, self.succ
16 if not pred[a] and not succ[a] then
19 if not pred[b] and not succ[b] then
22 pred[b] = pred[b] or {}
23 pred[b][#pred[b] + 1] = a
24 succ[a] = succ[a] or {}
25 succ[a][#succ[a] + 1] = b
28 function DAG:applyOnModules(f, t1, t2)
29 if torch.type(t1) == 'table' then
31 for k, s in pairs(t1) do
32 result[k] = self:applyOnModules(f, s, t2 and t2[k])
40 function DAG:setInput(i)
45 function DAG:setOutput(o)
47 self.outputModules = o
57 self:applyOnModules(function(m) distance[m] = 1 end, self.inputModules)
63 for i, isucc in pairs(self.succ) do
64 for _, j in pairs(isucc) do
65 if distance[i] and (not distance[j] or distance[j] < distance[i] + 1) then
66 distance[j] = distance[i] + 1
74 for i, d in pairs(distance) do
75 table.insert(self.sorted, { d, i })
78 table.sort(self.sorted, function(a, b) return a[1] < b[1] end)
79 for i, a in ipairs(self.sorted) do self.sorted[i] = a[2] end
85 for i, d in ipairs(self.sorted) do
86 print('#' .. i .. ' -> ' .. torch.type(d))
90 function DAG:updateOutput(input)
93 self:applyOnModules(function(m, i) m:updateOutput(i) end, self.inputModules, input)
95 for _, d in ipairs(self.sorted) do
97 if #self.pred[d] == 1 then
98 d:updateOutput(self.pred[d][1].output)
99 elseif #self.pred[d] > 1 then
101 for k = 1, #self.pred[d] do
102 c[k] = self.pred[d][k].output
109 self.output = self:applyOnModules(function(m) return m.output end, self.outputModules)
114 function DAG:updateGradInput(input, gradOutput)