X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=stack.py;fp=stack.py;h=0000000000000000000000000000000000000000;hb=b564b20368674d10ebefb56a99b7732827820d3b;hp=544306c2516307bac6bc47936d96766db8c84400;hpb=c17cb30d6d5bfa2f450dd9cc3d2d931fce9cbdab;p=pytorch.git diff --git a/stack.py b/stack.py deleted file mode 100755 index 544306c..0000000 --- a/stack.py +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env python - -# Any copyright is dedicated to the Public Domain. -# https://creativecommons.org/publicdomain/zero/1.0/ - -# Written by Francois Fleuret - -from torch import Tensor - -import sys - -def exception_hook(exc_type, exc_value, tb): - r'''Hacks the call stack message in case of RuntimeError to show all - the local variables, and indicate for every tensor its shape, - dtype and device. - - ''' - - repr_orig=Tensor.__repr__ - Tensor.__repr__=lambda x: f'{x.size()}:{x.dtype}:{x.device}' - - while tb: - print('--------------------------------------------------') - filename = tb.tb_frame.f_code.co_filename - name = tb.tb_frame.f_code.co_name - line_no = tb.tb_lineno - print(f' File "{filename}", line {line_no}, in {name}') - print(open(filename, 'r').readlines()[line_no-1], end='') - - if exc_type is RuntimeError: - for n,v in tb.tb_frame.f_locals.items(): - print(f' {n} -> {v}') - - tb = tb.tb_next - - Tensor.__repr__=repr_orig - - print(f'{exc_type.__name__}: {exc_value}') - -sys.excepthook = exception_hook - -###################################################################### - -if __name__ == '__main__': - - import torch - - def dummy(a,b): - print(a@b) - - def blah(a,b): - c=b+b - dummy(a,c) - - mmm=torch.randn(2,3) - xxx=torch.randn(3) - #print(xxx@mmm) - blah(mmm,xxx) - blah(xxx,mmm)