for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end
end
--- This accumulate x in a where they are both nested tables of
--- tensors. If first is true, set a = x.
+-- This accumulates x in a where they are both nested tables of
+-- tensors. If first is true, set a = x. Behavior is undefined if a
+-- and x do not have the exact same structure.
function DAG:nestedAccTensor(a, x, first)
if torch.type(x) == 'table' then
- a = a or {}
+ local b = {}
for i in pairs(x) do
- a[i] = self:nestedAccTensor(a[i], x[i], first)
+ b[i] = self:nestedAccTensor(a[i], x[i], first)
end
+ a = b
else
if first then
if a then
self:nestedApply(
function(nnm, i)
- self.node[nnm].input = i
- self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
+ local node = self.node[nnm]
+ node.input = i
+ self:rethrowErrors(nnm, node.index, 'updateOutput', i)
end,
self.inputModules,
input
end
end
node.input = i
- self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
+ self:rethrowErrors(nnm, node.index, 'updateOutput', i)
end
end
function(nnm, go)
local node = self.node[nnm]
node.gradOutput = go
- self:rethrowErrors(nnm, node.index, 'updateGradInput', self.node[nnm].input, go)
+ self:rethrowErrors(nnm, node.index, 'updateGradInput', node.input, go)
end,
self.outputModules, gradOutput
)
if #node.gradInputSucc > 0 then
self:updateGradOutput(node)
- self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', node.input, node.gradOutput)
+ self:rethrowErrors(nnm, node.index, 'updateGradInput', node.input, node.gradOutput)
end
-- We fill the gradInputSucc of our predecessors
end
function DAG:accGradParameters(input, gradOutput, scale)
- scale = scale or 1
-
assert(self.sorted, 'There has been a DAG structure change before a DAG:accGradParameters')
self:nestedApply(