X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=sizer.py;h=dff36ebc6ba9f279e270339f8a6c548bb8140f86;hp=52620e88aac6600f5d5e80400a7162abc84991b2;hb=HEAD;hpb=8a6b6efe651113cc3e8eb13cb13059724948bc9d diff --git a/sizer.py b/sizer.py index 52620e8..5887e4a 100755 --- a/sizer.py +++ b/sizer.py @@ -1,10 +1,37 @@ #!/usr/bin/env python +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + import os, stat, sys import time import torch from torch import nn +###################################################################### + +if len(sys.argv) < 2: + print( + sys.argv[0] + + """ + +For example: + +(17, 3, 60, 80) +nn.Conv2d(3, 32, 3, padding = 1) +nn.MaxPool2d(2) +nn.Conv2d(32, 32, 3, padding = 1) +nn.MaxPool2d(2) +nn.Conv2d(32, 64, 3, padding = 1) +nn.MaxPool2d(5) +nn.Conv2d(64, 64, (3, 4))""" + ) + exit(1) + +###################################################################### + t = 0 while True: @@ -12,15 +39,15 @@ while True: t = os.stat(sys.argv[1])[stat.ST_MTIME] if t > pt: pt = t - os.system('clear') + os.system("clear") try: - temp = [l.strip('\n\r') for l in open(sys.argv[1], 'r').readlines()] + temp = [l.strip("\n\r") for l in open(sys.argv[1], "r").readlines()] x = torch.zeros(eval(temp.pop(0))) - print('-> ' + str(tuple(x.size()))) + print("-> " + str(tuple(x.size()))) for k in temp: - print(' ' + k) - x = eval(k + '(x)') - print('-> ' + str(tuple(x.size()))) + print(" " + k) + x = eval(k + "(x)") + print("-> " + str(tuple(x.size()))) except: - print('** Error **') + print("** Error **") time.sleep(1)