X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=dagnn.git;a=blobdiff_plain;f=test-dagnn.lua;h=6c09f95298a072c7fe32ca3b4ad430755f6fb1c6;hp=a45d6365d61a6183b7b1b49758cddb89272d7708;hb=da3a60ffa7e1a39e4d01b405c2d80d84c3722c2c;hpb=682b76200f755f5f16477e086056a86cafdea1cd diff --git a/test-dagnn.lua b/test-dagnn.lua index a45d636..6c09f95 100755 --- a/test-dagnn.lua +++ b/test-dagnn.lua @@ -5,6 +5,17 @@ require 'nn' require 'dagnn' +function printTensorTable(t) + if torch.type(t) == 'table' then + for i, t in pairs(t) do + print('-- ELEMENT [' .. i .. '] --') + printTensorTable(t) + end + else + print(tostring(t)) + end +end + -- torch.setnumthreads(params.nbThreads) torch.setdefaulttensortype('torch.DoubleTensor') torch.manualSeed(2) @@ -16,41 +27,32 @@ d = nn.Linear(10, 3) e = nn.CMulTable() f = nn.Linear(3, 2) ---[[ - - a -----> b ---> c ----> e --- - \ / - \--> d ---/ - \ - \---> f --- -]]-- +-- a -----> b ---> c ----> e --- +-- \ / +-- \--> d ---/ +-- \ +-- \---> f --- -g = nn.DAG:new() - -g:setInput(a) -g:setOutput({ e }) +g = nn.DAG() g:addEdge(c, e) g:addEdge(a, b) g:addEdge(d, e) g:addEdge(b, c) g:addEdge(b, d) --- g:addEdge(d, f) +g:addEdge(d, f) --- g = torch.load('dag.t7') +g:setInput({a}) +g:setOutput({e,f}) g:print() input = torch.Tensor(3, 10):uniform() -output = g:updateOutput(input) +output = g:updateOutput({input}) -if torch.type(output) == 'table' then - for i, t in pairs(output) do - print(tostring(i) .. ' -> ' .. tostring(t)) - end -else - print(tostring(output)) -end +printTensorTable(output) + +---------------------------------------------------------------------- -torch.save('dag.t7', g) +-- gradInput = g:updateGradInput({ input }, output)