end
function DAG:addEdge(a, b)
+ self.sorted = nil
local pred, succ = self.pred, self.succ
if not pred[a] and not succ[a] then
self:add(a)
end
function DAG:setInput(i)
+ self.sorted = nil
if torch.type(i) == 'table' then
self.inputModules = i
for _, m in ipairs(i) do
end
function DAG:setOutput(o)
+ self.sorted = nil
if torch.type(o) == 'table' then
self.outputModules = o
for _, m in ipairs(o) do
end
end
-function DAG:order()
+function DAG:sort()
+ if self.sorted then
+ return
+ end
+
local distance = {}
for _, a in pairs(self.inputModules) do
end
function DAG:print()
+ self:sort()
+
for i, d in ipairs(self.sorted) do
print('#' .. i .. ' -> ' .. torch.type(d))
end
end
function DAG:updateOutput(input)
+ self:sort()
+
if #self.inputModules == 1 then
self.inputModules[1]:updateOutput(input)
else
return self.output
end
+
+function DAG:updateGradInput(input, gradOutput)
+ self:sort()
+end
+
+return DAG