1b467e720542469d45228b1dbc8a8fd0b021f6ad
[dagnn.git] / dagnn.lua
1
2 require 'torch'
3 require 'nn'
4
5 local DAG, parent = torch.class('nn.DAG', 'nn.Container')
6
7 function DAG:__init()
8    parent.__init(self)
9    self.pred = {}
10    self.succ = {}
11 end
12
13 function DAG:addEdge(a, b)
14    self.sorted = nil
15    local pred, succ = self.pred, self.succ
16    if not pred[a] and not succ[a] then
17       self:add(a)
18    end
19    if not pred[b] and not succ[b] then
20       self:add(b)
21    end
22    pred[b] = pred[b] or {}
23    pred[b][#pred[b] + 1] = a
24    succ[a] = succ[a] or {}
25    succ[a][#succ[a] + 1] = b
26 end
27
28 -- Apply f on t recursively; use the corresponding a1 and a2 elements
29 -- (i.e. same keys) as second and third parameters to f when
30 -- available; return the results from f, organized in a similarly
31 -- nested table.
32 function DAG:applyOnModules(f, t, a1, a2)
33    if torch.type(t) == 'table' then
34       local result = {}
35       for k, s in pairs(t) do
36          result[k] = self:applyOnModules(f, s, a1 and a1[k], a2 and a2[k])
37       end
38       return result
39    else
40       return f(t, a1, a2)
41    end
42 end
43
44 function DAG:setInput(i)
45    self.sorted = nil
46    self.inputModules = i
47    self:applyOnModules(
48       function(m)
49          if not self.succ[m] or #self.succ[m] == 0 then
50             error('Input modules must have outgoing  edges.')
51          end
52          if self.pred[m] and #self.pred[m] > 0 then
53             error('Input modules cannog have incoming edges.')
54          end
55       end,
56       self.inputModules
57    )
58 end
59
60 function DAG:setOutput(o)
61    self.sorted = nil
62    self.outputModules = o
63    self:applyOnModules(
64       function(m)
65          if not self.pred[m] or #self.pred[m] == 0 then
66             error('Output module must have incoming edges.')
67          end
68          if self.succ[m] and #self.succ[m] > 0 then
69             error('Output module cannot have outgoing edges.')
70          end
71       end,
72       self.outputModules
73    )
74 end
75
76 function DAG:sort()
77    if self.sorted then
78       return
79    end
80
81    local distance = {}
82
83    self:applyOnModules(function(m) distance[m] = 1 end, self.inputModules)
84
85    local nc
86
87    repeat
88       nc = 0
89       for i, isucc in pairs(self.succ) do
90          for _, j in pairs(isucc) do
91             if distance[i] and (not distance[j] or distance[j] < distance[i] + 1) then
92                distance[j] = distance[i] + 1
93                nc = nc + 1
94             end
95          end
96       end
97    until nc == 0
98
99    self.sorted = { }
100    for i, d in pairs(distance) do
101       table.insert(self.sorted, { d, i })
102    end
103
104    table.sort(self.sorted, function(a, b) return a[1] < b[1] end)
105    for i, a in ipairs(self.sorted) do self.sorted[i] = a[2] end
106 end
107
108 function DAG:print()
109    self:sort()
110
111    for i, d in ipairs(self.sorted) do
112       print('#' .. i .. ' -> ' .. torch.type(d))
113    end
114 end
115
116 function DAG:updateOutput(input)
117    self:sort()
118
119    self:applyOnModules(function(m, i) m:updateOutput(i) end, self.inputModules, input)
120
121    for _, d in ipairs(self.sorted) do
122       if self.pred[d] then
123          if #self.pred[d] == 1 then
124             d:updateOutput(self.pred[d][1].output)
125          elseif #self.pred[d] > 1 then
126             local c = {}
127             for k = 1, #self.pred[d] do
128                c[k] = self.pred[d][k].output
129             end
130             d:updateOutput(c)
131          end
132       end
133    end
134
135    self.output = self:applyOnModules(function(m) return m.output end, self.outputModules)
136
137    return self.output
138 end
139
140 function DAG:updateGradInput(input, gradOutput)
141    self:sort()
142
143    self:applyOnModules(
144       function(m, i, go) m:updateGradInput(i, go) end,
145       self.outputModules, input, gradOutput
146    )
147
148    for k = self.sorted, 1, -1 do
149       local m = sorted[k]
150       if self.succ[d] then
151          if #self.succ[d] == 1 then
152             d:updateGradInput(self.succ[d][1].gradInput)
153          elseif #self.succ[d] > 1 then
154             local sum
155             for k = 1, #self.succ[d] do
156                if sum then
157                   sum:add(self.succ[d][k].gradInput)
158                else
159                   sum = self.succ[d][k].gradInput:clone()
160                end
161             end
162             d:updateGradInput(sum)
163          end
164       end
165    end
166
167    self.gradInput = self:applyOnModules(function(m) return m.gradInput end, self.inputModules)
168
169    return self.gradInput
170 end
171
172 return DAG