X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=dagnn.lua;h=1b467e720542469d45228b1dbc8a8fd0b021f6ad;hb=e5030cca047eed4b8c5db172fc52e893b1b1d843;hp=a6414b3b569d92c28af9387aa0d6c2f31d8e63ef;hpb=da3a60ffa7e1a39e4d01b405c2d80d84c3722c2c;p=dagnn.git diff --git a/dagnn.lua b/dagnn.lua index a6414b3..1b467e7 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -46,8 +46,11 @@ function DAG:setInput(i) self.inputModules = i self:applyOnModules( function(m) - if (not self.succ[m] or #self.succ[m] == 0) or (self.pred[m] and #self.pred[m] > 0) then - error('Invalid input edges.') + if not self.succ[m] or #self.succ[m] == 0 then + error('Input modules must have outgoing edges.') + end + if self.pred[m] and #self.pred[m] > 0 then + error('Input modules cannog have incoming edges.') end end, self.inputModules @@ -59,8 +62,11 @@ function DAG:setOutput(o) self.outputModules = o self:applyOnModules( function(m) - if (not self.pred[m] or #self.pred[m] == 0) or (self.succ[m] and #self.succ[m] > 0) then - error('Invalid output edges.') + if not self.pred[m] or #self.pred[m] == 0 then + error('Output module must have incoming edges.') + end + if self.succ[m] and #self.succ[m] > 0 then + error('Output module cannot have outgoing edges.') end end, self.outputModules @@ -134,7 +140,10 @@ end function DAG:updateGradInput(input, gradOutput) self:sort() - self:applyOnModules(function(m, i, go) m:updateGradInput(i, go) end, self.outputModules, input, gradOutput) + self:applyOnModules( + function(m, i, go) m:updateGradInput(i, go) end, + self.outputModules, input, gradOutput + ) for k = self.sorted, 1, -1 do local m = sorted[k]