X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=dagnn.git;a=blobdiff_plain;f=dagnn.lua;h=b82398c0de429bd19875776fa465222a84504bbe;hp=ca2692692379f7d19db4aa82e7c6af7a48bbd629;hb=HEAD;hpb=78671207db483567021a935d0738ba85d8b16551 diff --git a/dagnn.lua b/dagnn.lua index ca26926..b82398c 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -69,6 +69,7 @@ function DAG:putInOrder() local nc local nl = 0 repeat + assert(nl < #self.modules, 'Cycle detected in the graph.') nc = 0 for nnma, node in pairs(self.node) do for _, nnmb in pairs(node.succ) do @@ -78,12 +79,11 @@ function DAG:putInOrder() end end end - assert(nl < #self.modules, 'Cycle detected in the graph.') nl = nl + 1 until nc == 0 for _, nnm in pairs(self.modules) do - assert(distance[nnm], 'Some modules are not connected to inputs') + assert(distance[nnm], 'Some modules are not connected to inputs.') end self.sorted = {} @@ -148,6 +148,10 @@ function DAG:connect(...) end end +function DAG:setLabel(nnm, label) + self.node[nnm].label = label +end + function DAG:setInput(i) self.sorted = nil self.inputModules = i @@ -176,7 +180,11 @@ function DAG:print() self:putInOrder() for i, d in ipairs(self.sorted) do - print('#' .. i .. ' -> ' .. torch.type(d)) + local decoration = '' + if self.node[d].label then + decoration = ' [' .. self.node[d].label .. ']' + end + print('#' .. i .. ' -> ' .. torch.type(d) .. decoration) end end @@ -211,7 +219,7 @@ function DAG:saveDot(filename) file:write( ' ' .. node.index - .. ' [shape=box,label=\"' .. torch.type(nnmb) .. '\"]' + .. ' [shape=box,label=\"' .. (self.node[nnmb].label or torch.type(nnmb)) .. '\"]' .. '\n' ) @@ -280,7 +288,7 @@ function DAG:updateOutput(input) end function DAG:updateGradInput(input, gradOutput) - assert(self.sorted, 'There has been a structure change before a DAG:updateGradInput') + assert(self.sorted, 'There has been a structure change before a DAG:updateGradInput.') self:nestedApply( function(nnm, go) @@ -315,7 +323,7 @@ function DAG:updateGradInput(input, gradOutput) table.insert(self.node[pred[1]].gradInputSucc, nnm.gradInput) elseif #pred > 1 then assert(torch.type(nnm.gradInput) == 'table', - 'Should have a table gradInput since it has multiple predecessors') + 'Should have a table gradInput since it has multiple predecessors.') for n = 1, #pred do table.insert(self.node[pred[n]].gradInputSucc, nnm.gradInput[n]) end @@ -331,7 +339,7 @@ function DAG:updateGradInput(input, gradOutput) end function DAG:accGradParameters(input, gradOutput, scale) - assert(self.sorted, 'There has been a structure change before a DAG:accGradParameters') + assert(self.sorted, 'There has been a structure change before a DAG:accGradParameters.') self:nestedApply( function(nnm, go) self.node[nnm].gradOutput = go end,