# Written by Francois Fleuret <francois@fleuret.org>
-from torch import is_tensor
+from torch import is_tensor, Tensor
import sys
def exception_hook(exc_type, exc_value, tb):
-# tb = tb.tb_next
+ repr_orig=Tensor.__repr__
+ Tensor.__repr__=lambda x: f'{x.size()}:{x.dtype}:{x.device}'
while tb:
- # x=tb.tb_frame #.f_code
- # for field in dir(x):
- # print(f'@@@ {field} {getattr(x, field)}')
print('--------------------------------------------------')
filename = tb.tb_frame.f_code.co_filename
name = tb.tb_frame.f_code.co_name
print(f' File "{filename}", line {line_no}, in {name}')
print(open(filename, 'r').readlines()[line_no-1], end='')
- local_vars = tb.tb_frame.f_locals
-
- for n,v in local_vars.items():
- if is_tensor(v):
- print(f' {n} -> {tuple(v.size())}:{v.dtype}:{v.device}')
- else:
+ if exc_type is RuntimeError:
+ for n,v in tb.tb_frame.f_locals.items():
print(f' {n} -> {v}')
tb = tb.tb_next
+ Tensor.__repr__=repr_orig
+
print(f'{exc_type.__name__}: {exc_value}')
sys.excepthook = exception_hook
#print(xxx@mmm)
blah(mmm,xxx)
blah(xxx,mmm)
-