From 5ab8211805831629148d7b436b8770590f1987b0 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Sun, 6 Jun 2021 14:22:08 +0200 Subject: [PATCH] Initial commit. --- conv_chain.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100755 conv_chain.py diff --git a/conv_chain.py b/conv_chain.py new file mode 100755 index 0000000..04dfdfa --- /dev/null +++ b/conv_chain.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python + +import torch +from torch import nn + +###################################################################### + +def conv_chain(input_size, output_size, depth, cond): + if depth == 0: + if input_size == output_size: + return [ [ ] ] + else: + return [ ] + else: + r = [ ] + for kernel_size in range(1, input_size + 1): + for stride in range(1, input_size + 1): + if cond(kernel_size, stride): + n = (input_size - kernel_size) // stride + if n * stride + kernel_size == input_size: + q = conv_chain(n + 1, output_size, depth - 1, cond) + r += [ [ (kernel_size, stride) ] + u for u in q ] + return r + +###################################################################### + +# Example + +c = conv_chain( + input_size = 64, output_size = 8, + depth = 5, + cond = lambda k, s: k <= 4 and s <= 2 and s <= k//2 +) + +x = torch.rand(1, 1, 64) + +for m in c: + m = nn.Sequential(*[ nn.Conv1d(1, 1, l[0], l[1]) for l in m ]) + print(m) + print(x.size(), m(x).size()) + +###################################################################### -- 2.39.5