projects
/
dagnn.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[dagnn.git]
/
dagnn.lua
diff --git
a/dagnn.lua
b/dagnn.lua
index
52913ad
..
4841843
100755
(executable)
--- a/
dagnn.lua
+++ b/
dagnn.lua
@@
-11,6
+11,7
@@
function DAG:__init()
end
function DAG:addEdge(a, b)
end
function DAG:addEdge(a, b)
+ self.sorted = nil
local pred, succ = self.pred, self.succ
if not pred[a] and not succ[a] then
self:add(a)
local pred, succ = self.pred, self.succ
if not pred[a] and not succ[a] then
self:add(a)
@@
-25,6
+26,7
@@
function DAG:addEdge(a, b)
end
function DAG:setInput(i)
end
function DAG:setInput(i)
+ self.sorted = nil
if torch.type(i) == 'table' then
self.inputModules = i
for _, m in ipairs(i) do
if torch.type(i) == 'table' then
self.inputModules = i
for _, m in ipairs(i) do
@@
-38,6
+40,7
@@
function DAG:setInput(i)
end
function DAG:setOutput(o)
end
function DAG:setOutput(o)
+ self.sorted = nil
if torch.type(o) == 'table' then
self.outputModules = o
for _, m in ipairs(o) do
if torch.type(o) == 'table' then
self.outputModules = o
for _, m in ipairs(o) do
@@
-50,7
+53,11
@@
function DAG:setOutput(o)
end
end
end
end
-function DAG:order()
+function DAG:sort()
+ if self.sorted then
+ return
+ end
+
local distance = {}
for _, a in pairs(self.inputModules) do
local distance = {}
for _, a in pairs(self.inputModules) do
@@
-81,12
+88,16
@@
function DAG:order()
end
function DAG:print()
end
function DAG:print()
+ self:sort()
+
for i, d in ipairs(self.sorted) do
print('#' .. i .. ' -> ' .. torch.type(d))
end
end
function DAG:updateOutput(input)
for i, d in ipairs(self.sorted) do
print('#' .. i .. ' -> ' .. torch.type(d))
end
end
function DAG:updateOutput(input)
+ self:sort()
+
if #self.inputModules == 1 then
self.inputModules[1]:updateOutput(input)
else
if #self.inputModules == 1 then
self.inputModules[1]:updateOutput(input)
else
@@
-120,3
+131,9
@@
function DAG:updateOutput(input)
return self.output
end
return self.output
end
+
+function DAG:updateGradInput(input, gradOutput)
+ self:sort()
+end
+
+return DAG