X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=stack.py;h=0838bee9e025340f3058fedb7478ce14200ac788;hb=3814647c468e48a12543519b7ce7e584936e24ee;hp=27de2040b98c342cf48468b27d9a0e5c0b834ba3;hpb=d37e52d43cec4e3044c4a986f5be2a2553b4e1e1;p=pytorch.git diff --git a/stack.py b/stack.py index 27de204..0838bee 100755 --- a/stack.py +++ b/stack.py @@ -5,35 +5,31 @@ # Written by Francois Fleuret -from torch import is_tensor +from torch import Tensor import sys def exception_hook(exc_type, exc_value, tb): - tb = tb.tb_next + repr_orig=Tensor.__repr__ + Tensor.__repr__=lambda x: f'{x.size()}:{x.dtype}:{x.device}' while tb: - #x=tb.tb_frame.f_code - # for field in dir(x): - # print(f'@@@ {field} {getattr(x, field)}') - + print('--------------------------------------------------') filename = tb.tb_frame.f_code.co_filename name = tb.tb_frame.f_code.co_name line_no = tb.tb_lineno print(f' File "{filename}", line {line_no}, in {name}') print(open(filename, 'r').readlines()[line_no-1], end='') - local_vars = tb.tb_frame.f_locals - - for n,v in local_vars.items(): - if is_tensor(v): - print(f' {n} -> {v.size()}:{v.dtype}:{v.device}') - else: + if exc_type is RuntimeError: + for n,v in tb.tb_frame.f_locals.items(): print(f' {n} -> {v}') tb = tb.tb_next + Tensor.__repr__=repr_orig + print(f'{exc_type.__name__}: {exc_value}') sys.excepthook = exception_hook @@ -51,8 +47,8 @@ if __name__ == '__main__': c=b+b dummy(a,c) - m=torch.randn(2,3) - x=torch.randn(3) - blah(m,x) - blah(x,m) - + mmm=torch.randn(2,3) + xxx=torch.randn(3) + #print(xxx@mmm) + blah(mmm,xxx) + blah(xxx,mmm)