end
until nc == 0
+ for _, nnm in pairs(self.modules) do
+ assert(distance[nnm], 'Some modules are not connected to inputs')
+ end
+
self.sorted = {}
for m, d in pairs(distance) do
table.insert(self.sorted, { distance = d, nnm = m })
for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end
end
--- This accumulate x in a where they are both nested tables of
--- tensors. If first is true, set a = x.
+-- This accumulates x in a where they are both nested tables of
+-- tensors. If first is true, set a = x. Behavior is undefined if a
+-- and x do not have the exact same structure.
function DAG:nestedAccTensor(a, x, first)
if torch.type(x) == 'table' then
- a = a or {}
+ local b = {}
for i in pairs(x) do
- a[i] = self:nestedAccTensor(a[i], x[i], first)
+ b[i] = self:nestedAccTensor(a[i], x[i], first)
end
+ a = b
else
if first then
if a then
self.inputModules = i
self:nestedApply(
function(nnm)
- if #self.node[nnm].succ == 0 then
- error('Input modules must have outgoing edges.')
- end
- if #self.node[nnm].pred > 0 then
- error('Input modules cannot have incoming edges.')
- end
+ assert(#self.node[nnm].succ > 0, 'Input modules must have outgoing edges.')
+ assert(#self.node[nnm].pred == 0, 'Input modules cannot have incoming edges.')
end,
self.inputModules
)
self.outputModules = o
self:nestedApply(
function(nnm)
- if #self.node[nnm].pred == 0 then
- error('Output module must have incoming edges.')
- end
- if #self.node[nnm].succ > 0 then
- error('Output module cannot have outgoing edges.')
- end
+ assert(#self.node[nnm].pred > 0, 'Output module must have incoming edges.')
+ assert(#self.node[nnm].succ == 0, 'Output module cannot have outgoing edges.')
end,
self.outputModules
)
self:nestedApply(
function(nnm, i)
- self.node[nnm].input = i
- self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
+ local node = self.node[nnm]
+ node.input = i
+ self:rethrowErrors(nnm, node.index, 'updateOutput', i)
end,
self.inputModules,
input
end
end
node.input = i
- self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
+ self:rethrowErrors(nnm, node.index, 'updateOutput', i)
end
end
function(nnm, go)
local node = self.node[nnm]
node.gradOutput = go
- self:rethrowErrors(nnm, node.index, 'updateGradInput', self.node[nnm].input, go)
+ self:rethrowErrors(nnm, node.index, 'updateGradInput', node.input, go)
end,
self.outputModules, gradOutput
)
if #node.gradInputSucc > 0 then
self:updateGradOutput(node)
- self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', node.input, node.gradOutput)
+ self:rethrowErrors(nnm, node.index, 'updateGradInput', node.input, node.gradOutput)
end
-- We fill the gradInputSucc of our predecessors
if #pred == 1 then
table.insert(self.node[pred[1]].gradInputSucc, nnm.gradInput)
elseif #pred > 1 then
- if not torch.type(nnm.gradInput) == 'table' then
- error('Should have a table gradInput since it has multiple predecessors')
- end
+ assert(torch.type(nnm.gradInput) == 'table',
+ 'Should have a table gradInput since it has multiple predecessors')
for n = 1, #pred do
table.insert(self.node[node.pred[n]].gradInputSucc, nnm.gradInput[n])
end
end
function DAG:accGradParameters(input, gradOutput, scale)
- scale = scale or 1
-
assert(self.sorted, 'There has been a DAG structure change before a DAG:accGradParameters')
self:nestedApply(