Replaced error() with assert().
[dagnn.git] / dagnn.lua
index 5921c05..cf45233 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -76,6 +76,10 @@ function DAG:putInOrder()
       end
    until nc == 0
 
+   for _, nnm in pairs(self.modules) do
+      assert(distance[nnm], 'Some modules are not connected to inputs')
+   end
+
    self.sorted = {}
    for m, d in pairs(distance) do
       table.insert(self.sorted, { distance = d, nnm = m })
@@ -142,12 +146,8 @@ function DAG:setInput(i)
    self.inputModules = i
    self:nestedApply(
       function(nnm)
-         if #self.node[nnm].succ == 0 then
-            error('Input modules must have outgoing  edges.')
-         end
-         if #self.node[nnm].pred > 0 then
-            error('Input modules cannot have incoming edges.')
-         end
+         assert(#self.node[nnm].succ > 0, 'Input modules must have outgoing edges.')
+         assert(#self.node[nnm].pred == 0, 'Input modules cannot have incoming edges.')
       end,
       self.inputModules
    )
@@ -158,12 +158,8 @@ function DAG:setOutput(o)
    self.outputModules = o
    self:nestedApply(
       function(nnm)
-         if #self.node[nnm].pred == 0 then
-            error('Output module must have incoming edges.')
-         end
-         if #self.node[nnm].succ > 0 then
-            error('Output module cannot have outgoing edges.')
-         end
+         assert(#self.node[nnm].pred > 0, 'Output module must have incoming edges.')
+         assert(#self.node[nnm].succ == 0, 'Output module cannot have outgoing edges.')
       end,
       self.outputModules
    )
@@ -292,9 +288,8 @@ function DAG:updateGradInput(input, gradOutput)
       if #pred == 1 then
          table.insert(self.node[pred[1]].gradInputSucc, nnm.gradInput)
       elseif #pred > 1 then
-         if not torch.type(nnm.gradInput) == 'table' then
-            error('Should have a table gradInput since it has multiple predecessors')
-         end
+         assert(torch.type(nnm.gradInput) == 'table',
+                'Should have a table gradInput since it has multiple predecessors')
          for n = 1, #pred do
             table.insert(self.node[node.pred[n]].gradInputSucc, nnm.gradInput[n])
          end