X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tensorstack.py;h=074588e891f8f3c6a0940c00c1f3c3388ff8695e;hb=763e18680a0ae2db64d27c9f2a2054a7403009c0;hp=3218be16b749c2fde21b887811ce730df690d34f;hpb=d92251f850bfcc181ec51ac1907e5fb5e4d693d5;p=beaver.git diff --git a/tensorstack.py b/tensorstack.py index 3218be1..074588e 100755 --- a/tensorstack.py +++ b/tensorstack.py @@ -11,9 +11,9 @@ import sys def exception_hook(exc_type, exc_value, tb): - r"""Hacks the call stack message to show all the local variables in - case of RuntimeError or ValueError, and prints tensors as shape, - dtype and device. + r"""Hacks the call stack message to show all the local variables + in case of relevant error, and prints tensors as shape, dtype and + device. """ @@ -28,7 +28,7 @@ def exception_hook(exc_type, exc_value, tb): print(f' File "{filename}", line {line_no}, in {name}') print(open(filename, "r").readlines()[line_no - 1]) - if exc_type in {RuntimeError, ValueError}: + if exc_type in {RuntimeError, ValueError, IndexError}: for n, v in tb.tb_frame.f_locals.items(): print(f" {n} -> {v}") @@ -45,7 +45,6 @@ sys.excepthook = exception_hook ###################################################################### if __name__ == "__main__": - import torch def dummy(a, b):