X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=test-dagnn.lua;h=d7179cc1b5110edb2b6666ce8a1604f6fe6d2102;hb=31dc42fc93ed12491ceb10ef3bfc4296878380ee;hp=a0a81ab0e146988b3bc7859d4a90366b03e57b1f;hpb=be03a73e411d18082a2dd99bff5df45c085017ca;p=dagnn.git diff --git a/test-dagnn.lua b/test-dagnn.lua index a0a81ab..d7179cc 100755 --- a/test-dagnn.lua +++ b/test-dagnn.lua @@ -5,6 +5,21 @@ require 'nn' require 'dagnn' +function printTensorTable(t) + if torch.type(t) == 'table' then + for i, t in pairs(t) do + print('-- ELEMENT [' .. i .. '] --') + printTensorTable(t) + end + else + print(tostring(t)) + end +end + +-- torch.setnumthreads(params.nbThreads) +torch.setdefaulttensortype('torch.DoubleTensor') +torch.manualSeed(2) + a = nn.Linear(10, 10) b = nn.ReLU() c = nn.Linear(10, 3) @@ -12,19 +27,14 @@ d = nn.Linear(10, 3) e = nn.CMulTable() f = nn.Linear(3, 2) ---[[ - - a -----> b ---> c ----> e --- - \ / - \--> d ---/ - \ - \---> f --- -]]-- +-- a -----> b ---> c ----> e --- +-- \ / +-- \--> d ---/ +-- \ +-- \---> f --- -g = DAG:new() +g = nn.DAG() -g:setInput(a) -g:setOutput({ e, f }) g:addEdge(c, e) g:addEdge(a, b) g:addEdge(d, e) @@ -32,13 +42,22 @@ g:addEdge(b, c) g:addEdge(b, d) g:addEdge(d, f) -g:order() +g:setInput({{a}}) +g:setOutput({ e, f }) -g:print(graph) +g:print() input = torch.Tensor(3, 10):uniform() -output = g:updateOutput(input) +output = g:updateOutput({{ input }}) + +printTensorTable(output) + +---------------------------------------------------------------------- + +print('******************************************************************') +print('** updateGradInput ***********************************************') +print('******************************************************************') +gradInput = g:updateGradInput({{input}}, output) -print(output[1]) -print(output[2]) +printTensorTable(gradInput)