projects
/
dagnn.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (from parent 1:
3c69b13
)
Remove the clone() for node.gradOutput when possible.
author
Francois Fleuret
<francois@fleuret.org>
Fri, 13 Jan 2017 15:50:57 +0000
(16:50 +0100)
committer
Francois Fleuret
<francois@fleuret.org>
Fri, 13 Jan 2017 15:50:57 +0000
(16:50 +0100)
dagnn.lua
patch
|
blob
|
history
diff --git
a/dagnn.lua
b/dagnn.lua
index
0c1d153
..
de9d29b
100755
(executable)
--- a/
dagnn.lua
+++ b/
dagnn.lua
@@
-86,20
+86,20
@@
function DAG:putInOrder()
for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end
end
for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end
end
-function DAG:
computeGradOutput(gradInputSucc
)
- local g
i
+function DAG:
updateGradOutput(node
)
+ local g
radInputSucc = node.gradInputSucc
if #gradInputSucc == 1 then
if #gradInputSucc == 1 then
- gi = gradInputSucc[1] -- we avoid a clone()
+ node.gradOutput = gradInputSucc[1]
elseif #gradInputSucc > 1 then
elseif #gradInputSucc > 1 then
- for k = 1, #gradInputSucc do
- if gi then
- gi:add(gradInputSucc[k])
- else
- gi = gradInputSucc[k]:clone()
- end
+ if node.gradOutput then
+ node.gradOutput:resize(gradInputSucc[1]):copy(gradInputSucc[1])
+ else
+ node.gradOutput = gradInputSucc[1]:clone()
+ end
+ for k = 2, #gradInputSucc do
+ node.gradOutput:add(gradInputSucc[k])
end
end
end
end
- return gi
end
----------------------------------------------------------------------
end
----------------------------------------------------------------------
@@
-206,7
+206,6
@@
function DAG:updateOutput(input)
self:nestedApply(
function(nnm, i)
self.node[nnm].input = i
self:nestedApply(
function(nnm, i)
self.node[nnm].input = i
- -- nnm:updateOutput(i)
self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
end,
self.inputModules,
self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
end,
self.inputModules,
@@
-226,7
+225,6
@@
function DAG:updateOutput(input)
end
end
node.input = i
end
end
node.input = i
- -- nnm:updateOutput(i)
self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
end
end
self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i)
end
end
@@
-246,7
+244,6
@@
function DAG:updateGradInput(input, gradOutput)
function(nnm, go)
local node = self.node[nnm]
node.gradOutput = go
function(nnm, go)
local node = self.node[nnm]
node.gradOutput = go
- -- nnm:updateGradInput(self.node[nnm].input, go)
self:rethrowErrors(nnm, node.index, 'updateGradInput', self.node[nnm].input, go)
end,
self.outputModules, gradOutput
self:rethrowErrors(nnm, node.index, 'updateGradInput', self.node[nnm].input, go)
end,
self.outputModules, gradOutput
@@
-264,11
+261,10
@@
function DAG:updateGradInput(input, gradOutput)
for k = #self.sorted, 1, -1 do
local nnm = self.sorted[k]
local node = self.node[nnm]
for k = #self.sorted, 1, -1 do
local nnm = self.sorted[k]
local node = self.node[nnm]
- local pred
, gradInputSucc = node.pred, node.gradInputSucc
+ local pred
= node.pred
- if #gradInputSucc > 0 then
- node.gradOutput = self:computeGradOutput(gradInputSucc)
- -- nnm:updateGradInput(node.input, node.gradOutput)
+ if #node.gradInputSucc > 0 then
+ self:updateGradOutput(node)
self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', node.input, node.gradOutput)
end
self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', node.input, node.gradOutput)
end
@@
-308,7
+304,6
@@
function DAG:accGradParameters(input, gradOutput, scale)
for k = 1, #self.modules do
local nnm = self.modules[k]
local node = self.node[nnm]
for k = 1, #self.modules do
local nnm = self.modules[k]
local node = self.node[nnm]
- -- nnm:accGradParameters(node.input, node.gradOutput, scale)
self:rethrowErrors(nnm, k, 'accGradParameters', node.input, node.gradOutput, scale)
end
end
self:rethrowErrors(nnm, k, 'accGradParameters', node.input, node.gradOutput, scale)
end
end