X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=test-dagnn.lua;h=3b1e66a2f2ef13d602c24e20e1f5f6841e01cf39;hb=60568def49e4c624e54f53b4be5783d6cfbe1ea9;hp=a45d6365d61a6183b7b1b49758cddb89272d7708;hpb=682b76200f755f5f16477e086056a86cafdea1cd;p=dagnn.git diff --git a/test-dagnn.lua b/test-dagnn.lua index a45d636..3b1e66a 100755 --- a/test-dagnn.lua +++ b/test-dagnn.lua @@ -5,6 +5,17 @@ require 'nn' require 'dagnn' +function printTensorTable(t) + if torch.type(t) == 'table' then + for i, t in pairs(t) do + print('-- ELEMENT [' .. i .. '] --') + printTensorTable(t) + end + else + print(tostring(t)) + end +end + -- torch.setnumthreads(params.nbThreads) torch.setdefaulttensortype('torch.DoubleTensor') torch.manualSeed(2) @@ -16,41 +27,32 @@ d = nn.Linear(10, 3) e = nn.CMulTable() f = nn.Linear(3, 2) ---[[ - - a -----> b ---> c ----> e --- - \ / - \--> d ---/ - \ - \---> f --- -]]-- +-- a -----> b ---> c ----> e --- +-- \ / +-- \--> d ---/ +-- \ +-- \---> f --- -g = nn.DAG:new() - -g:setInput(a) -g:setOutput({ e }) +g = nn.DAG() g:addEdge(c, e) g:addEdge(a, b) g:addEdge(d, e) g:addEdge(b, c) g:addEdge(b, d) --- g:addEdge(d, f) +g:addEdge(d, f) --- g = torch.load('dag.t7') +g:setInput({a}) +g:setOutput({e, f}) g:print() input = torch.Tensor(3, 10):uniform() -output = g:updateOutput(input) +output = g:updateOutput({input}) -if torch.type(output) == 'table' then - for i, t in pairs(output) do - print(tostring(i) .. ' -> ' .. tostring(t)) - end -else - print(tostring(output)) -end +printTensorTable(output) + +---------------------------------------------------------------------- -torch.save('dag.t7', g) +-- gradInput = g:updateGradInput({ input }, output)