From: Francois Fleuret Date: Sun, 6 Jun 2021 12:30:34 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=ba45285b08782597aacd2764a7506b28a0fbf5d2;p=pytorch.git Update. --- diff --git a/conv_chain.py b/conv_chain.py index 04dfdfa..fa5d752 100755 --- a/conv_chain.py +++ b/conv_chain.py @@ -24,19 +24,21 @@ def conv_chain(input_size, output_size, depth, cond): ###################################################################### -# Example +if __name__ == "__main__": -c = conv_chain( - input_size = 64, output_size = 8, - depth = 5, - cond = lambda k, s: k <= 4 and s <= 2 and s <= k//2 -) + # Example -x = torch.rand(1, 1, 64) + c = conv_chain( + input_size = 64, output_size = 8, + depth = 5, + cond = lambda k, s: k <= 4 and s <= 2 and s <= k//2 + ) -for m in c: - m = nn.Sequential(*[ nn.Conv1d(1, 1, l[0], l[1]) for l in m ]) - print(m) - print(x.size(), m(x).size()) + x = torch.rand(1, 1, 64) + + for m in c: + model = nn.Sequential(*[ nn.Conv1d(1, 1, l[0], l[1]) for l in m ]) + print(model) + print(x.size(), model(x).size()) ######################################################################