X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=test-dagnn.lua;h=6c09f95298a072c7fe32ca3b4ad430755f6fb1c6;hb=da3a60ffa7e1a39e4d01b405c2d80d84c3722c2c;hp=262ea6fe3111830ab1f8270118b608725e124881;hpb=452781856eafd237579e5c90b6e345354df91b42;p=dagnn.git diff --git a/test-dagnn.lua b/test-dagnn.lua index 262ea6f..6c09f95 100755 --- a/test-dagnn.lua +++ b/test-dagnn.lua @@ -5,6 +5,21 @@ require 'nn' require 'dagnn' +function printTensorTable(t) + if torch.type(t) == 'table' then + for i, t in pairs(t) do + print('-- ELEMENT [' .. i .. '] --') + printTensorTable(t) + end + else + print(tostring(t)) + end +end + +-- torch.setnumthreads(params.nbThreads) +torch.setdefaulttensortype('torch.DoubleTensor') +torch.manualSeed(2) + a = nn.Linear(10, 10) b = nn.ReLU() c = nn.Linear(10, 3) @@ -12,19 +27,13 @@ d = nn.Linear(10, 3) e = nn.CMulTable() f = nn.Linear(3, 2) ---[[ - - a -----> b ---> c ----> e --- - \ / - \--> d ---/ - \ - \---> f --- -]]-- - -g = nn.DAG:new() +-- a -----> b ---> c ----> e --- +-- \ / +-- \--> d ---/ +-- \ +-- \---> f --- -g:setInput(a) -g:setOutput({ e, f }) +g = nn.DAG() g:addEdge(c, e) g:addEdge(a, b) @@ -33,11 +42,17 @@ g:addEdge(b, c) g:addEdge(b, d) g:addEdge(d, f) +g:setInput({a}) +g:setOutput({e,f}) + g:print() input = torch.Tensor(3, 10):uniform() -output = g:updateOutput(input) +output = g:updateOutput({input}) + +printTensorTable(output) + +---------------------------------------------------------------------- -print(output[1]) -print(output[2]) +-- gradInput = g:updateGradInput({ input }, output)