Update.
[dagnn.git] / dagnn.lua
1
2 require 'torch'
3 require 'nn'
4
5 local DAG, parent = torch.class('nn.DAG', 'nn.Container')
6
7 function DAG:__init()
8    parent.__init(self)
9    -- Nodes are indexed by the module they encompass
10    self.node = { }
11 end
12
13 function DAG:createNode(n)
14    if not self.node[n] then
15       self:add(n) -- Add it to the object as a Container
16       self.node[n] = {}
17       self.node[n].succ = {}
18       self.node[n].pred = {}
19    end
20 end
21
22 function DAG:addEdge(a, b)
23    self.sorted = nil
24    self:createNode(a)
25    self:createNode(b)
26    table.insert(self.node[b].pred, a)
27    table.insert(self.node[a].succ, b)
28 end
29
30 -- Apply f on t recursively; use the corresponding a1 and a2 elements
31 -- (i.e. same keys) as second and third parameters to f when
32 -- available; return the results from f, organized in a similarly
33 -- nested table.
34 function DAG:nestApply(f, t, a1, a2)
35    if torch.type(t) == 'table' then
36       local result = {}
37       for k, s in pairs(t) do
38          result[k] = self:nestApply(f, s, a1 and a1[k], a2 and a2[k])
39       end
40       return result
41    else
42       return f(t, a1, a2)
43    end
44 end
45
46 function DAG:setInput(i)
47    self.sorted = nil
48    self.inputModules = i
49    self:nestApply(
50       function(m)
51          if #self.node[m].succ == 0 then
52             error('Input modules must have outgoing  edges.')
53          end
54          if #self.node[m].pred > 0 then
55             error('Input modules cannog have incoming edges.')
56          end
57       end,
58       self.inputModules
59    )
60 end
61
62 function DAG:setOutput(o)
63    self.sorted = nil
64    self.outputModules = o
65    self:nestApply(
66       function(m)
67          if #self.node[m].pred == 0 then
68             error('Output module must have incoming edges.')
69          end
70          if #self.node[m].succ > 0 then
71             error('Output module cannot have outgoing edges.')
72          end
73       end,
74       self.outputModules
75    )
76 end
77
78 function DAG:putInOrder()
79    if self.sorted then
80       return
81    end
82
83    -- First, we sort the nodes according to the DAG order
84
85    local distance = {}
86
87    self:nestApply(function(m) distance[m] = 1 end, self.inputModules)
88
89    local nc
90
91    repeat
92       nc = 0
93       for i, node in pairs(self.node) do
94          for _, j in pairs(node.succ) do
95             if distance[i] and (not distance[j] or distance[j] < distance[i] + 1) then
96                distance[j] = distance[i] + 1
97                nc = nc + 1
98             end
99          end
100       end
101    until nc == 0
102
103    self.sorted = { }
104    for n, d in pairs(distance) do
105       table.insert(self.sorted, { distance = d, node = n })
106    end
107
108    table.sort(self.sorted, function(a, b) return a.distance < b.distance end)
109
110    for i, a in ipairs(self.sorted) do self.sorted[i] = a.node end
111 end
112
113 function DAG:print()
114    self:putInOrder()
115
116    for i, d in ipairs(self.sorted) do
117       print('#' .. i .. ' -> ' .. torch.type(d))
118    end
119 end
120
121 function DAG:updateOutput(input)
122    self:putInOrder()
123
124    self:nestApply(function(m, i) m:updateOutput(i) end, self.inputModules, input)
125
126    for _, m in ipairs(self.sorted) do
127       if #self.node[m].pred > 0 then
128          local i
129          if #self.node[m].pred == 1 then
130             i = self.node[m].pred[1].output
131          elseif #self.node[m].pred > 1 then
132             i = {}
133             for k = 1, #self.node[m].pred do
134                i[k] = self.node[m].pred[k].output
135             end
136          end
137          self.node[m].input = i
138          m:updateOutput(i)
139       end
140    end
141
142    self.output = self:nestApply(function(m) return m.output end, self.outputModules)
143
144    return self.output
145 end
146
147 function DAG:updateGradInput(input, gradOutput)
148    self:putInOrder()
149
150    self:nestApply(
151       function(m, go) m:updateGradInput(self.node[m].input, go) end,
152       self.outputModules, gradOutput
153    )
154
155    for _, node in pairs(self.node) do
156       node.gradInputSucc = {}
157    end
158
159    for k = #self.sorted, 1, -1 do
160       local m = self.sorted[k]
161       local node = self.node[m]
162       local pred, succ, gradInputSucc = node.pred, node.succ, node.gradInputSucc
163
164       -- We update m:gradInput
165       if #gradInputSucc == 1 then
166          m:updateGradInput(node.input, gradInputSucc[1])
167       elseif #gradInputSucc > 1 then
168          local sum
169          for k = 1, #succ do
170             if sum then
171                sum:add(succ[k].gradInput)
172             else
173                sum = succ[k].gradInput
174             end
175          end
176          m:updateGradInput(node.input, sum)
177       end
178
179       -- We fill the gradInputSucc of our predecessors
180       if #pred == 1 then
181          table.insert(self.node[pred[1]].gradInputSucc, node.gradInput)
182       elseif #pred > 1 then
183          for n = 1, #pred do
184             table.insert(self.node[node.pred[n]].gradInputSucc, m.gradInput[n])
185          end
186       end
187    end
188
189    self.gradInput = self:nestApply(function(m) return m.gradInput end, self.inputModules)
190
191    return self.gradInput
192 end
193
194 return DAG