for kernel_size in range(1, input_size + 1):
for stride in range(1, input_size):
if cond(remain_depth, kernel_size, stride):
n = (input_size - kernel_size) // stride + 1
for kernel_size in range(1, input_size + 1):
for stride in range(1, input_size):
if cond(remain_depth, kernel_size, stride):
n = (input_size - kernel_size) // stride + 1