65a30e212126532f8796f6c1e35208a018f421fe
[dagnn.git] / dagnn.lua
1
2 require 'torch'
3 require 'nn'
4
5 local DAG, parent = torch.class('nn.DAG', 'nn.Container')
6
7 function DAG:__init()
8    parent.__init(self)
9    self.pred = {}
10    self.succ = {}
11 end
12
13 function DAG:addEdge(a, b)
14    self.sorted = nil
15    local pred, succ = self.pred, self.succ
16    if not pred[a] and not succ[a] then
17       self:add(a)
18    end
19    if not pred[b] and not succ[b] then
20       self:add(b)
21    end
22    pred[b] = pred[b] or {}
23    pred[b][#pred[b] + 1] = a
24    succ[a] = succ[a] or {}
25    succ[a][#succ[a] + 1] = b
26 end
27
28 function DAG:applyOnModules(f, t1, t2)
29    if torch.type(t1) == 'table' then
30       local result = {}
31       for k, s in pairs(t1) do
32          result[k] = self:applyOnModules(f, s, t2 and t2[k])
33       end
34       return result
35    else
36       return f(t1, t2)
37    end
38 end
39
40 function DAG:setInput(i)
41    self.sorted = nil
42    self.inputModules = i
43 end
44
45 function DAG:setOutput(o)
46    self.sorted = nil
47    self.outputModules = o
48 end
49
50 function DAG:sort()
51    if self.sorted then
52       return
53    end
54
55    local distance = {}
56
57    self:applyOnModules(function(m) distance[m] = 1 end, self.inputModules)
58
59    local nc
60
61    repeat
62       nc = 0
63       for i, isucc in pairs(self.succ) do
64          for _, j in pairs(isucc) do
65             if distance[i] and (not distance[j] or distance[j] < distance[i] + 1) then
66                distance[j] = distance[i] + 1
67                nc = nc + 1
68             end
69          end
70       end
71    until nc == 0
72
73    self.sorted = { }
74    for i, d in pairs(distance) do
75       table.insert(self.sorted, { d, i })
76    end
77
78    table.sort(self.sorted, function(a, b) return a[1] < b[1] end)
79    for i, a in ipairs(self.sorted) do self.sorted[i] = a[2] end
80 end
81
82 function DAG:print()
83    self:sort()
84
85    for i, d in ipairs(self.sorted) do
86       print('#' .. i .. ' -> ' .. torch.type(d))
87    end
88 end
89
90 function DAG:updateOutput(input)
91    self:sort()
92
93    self:applyOnModules(function(m, i) m:updateOutput(i) end, self.inputModules, input)
94
95    for _, d in ipairs(self.sorted) do
96       if self.pred[d] then
97          if #self.pred[d] == 1 then
98             d:updateOutput(self.pred[d][1].output)
99          elseif #self.pred[d] > 1 then
100             local c = {}
101             for k = 1, #self.pred[d] do
102                c[k] = self.pred[d][k].output
103             end
104             d:updateOutput(c)
105          end
106       end
107    end
108
109    self.output = self:applyOnModules(function(m) return m.output end, self.outputModules)
110
111    return self.output
112 end
113
114 function DAG:updateGradInput(input, gradOutput)
115    self:sort()
116 end
117
118 return DAG