X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tensorstack.py;h=c9a6c2f301ad50b8ea75ff05b01dcf89b8578db9;hb=ca897077ed89fbc3c7e8d812ad262146a0c72b71;hp=544306c2516307bac6bc47936d96766db8c84400;hpb=b564b20368674d10ebefb56a99b7732827820d3b;p=pytorch.git diff --git a/tensorstack.py b/tensorstack.py index 544306c..c9a6c2f 100755 --- a/tensorstack.py +++ b/tensorstack.py @@ -10,8 +10,8 @@ from torch import Tensor import sys def exception_hook(exc_type, exc_value, tb): - r'''Hacks the call stack message in case of RuntimeError to show all - the local variables, and indicate for every tensor its shape, + r'''Hacks the call stack message to show all the local variables in + case of RuntimeError or ValueError, and prints tensors as shape, dtype and device. ''' @@ -20,17 +20,18 @@ def exception_hook(exc_type, exc_value, tb): Tensor.__repr__=lambda x: f'{x.size()}:{x.dtype}:{x.device}' while tb: - print('--------------------------------------------------') + print('--------------------------------------------------\n') filename = tb.tb_frame.f_code.co_filename name = tb.tb_frame.f_code.co_name line_no = tb.tb_lineno print(f' File "{filename}", line {line_no}, in {name}') - print(open(filename, 'r').readlines()[line_no-1], end='') + print(open(filename, 'r').readlines()[line_no-1]) - if exc_type is RuntimeError: + if exc_type in { RuntimeError, ValueError }: for n,v in tb.tb_frame.f_locals.items(): print(f' {n} -> {v}') + print() tb = tb.tb_next Tensor.__repr__=repr_orig