X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tensorstack.py;fp=tensorstack.py;h=0000000000000000000000000000000000000000;hb=046c2b8633a415e533ec14cb72d77845f0c3e85f;hp=584c12d463825438e5bf8d6ea1b1ccadf1a3894e;hpb=68aa86a6645dfef3f919aad5732a1a09db77bfae;p=picoclvr.git diff --git a/tensorstack.py b/tensorstack.py deleted file mode 100755 index 584c12d..0000000 --- a/tensorstack.py +++ /dev/null @@ -1,61 +0,0 @@ -#!/usr/bin/env python - -# Any copyright is dedicated to the Public Domain. -# https://creativecommons.org/publicdomain/zero/1.0/ - -# Written by Francois Fleuret - -from torch import Tensor - -import sys - - -def exception_hook(exc_type, exc_value, tb): - r"""Hacks the call stack message to show all the local variables in - case of RuntimeError or ValueError, and prints tensors as shape, - dtype and device. - - """ - - repr_orig = Tensor.__repr__ - Tensor.__repr__ = lambda x: f"{x.size()}:{x.dtype}:{x.device}" - - while tb: - print("--------------------------------------------------\n") - filename = tb.tb_frame.f_code.co_filename - name = tb.tb_frame.f_code.co_name - line_no = tb.tb_lineno - print(f' File "{filename}", line {line_no}, in {name}') - print(open(filename, "r").readlines()[line_no - 1]) - - if exc_type in {RuntimeError, ValueError}: - for n, v in tb.tb_frame.f_locals.items(): - print(f" {n} -> {v}") - - print() - tb = tb.tb_next - - Tensor.__repr__ = repr_orig - - print(f"{exc_type.__name__}: {exc_value}") - - -sys.excepthook = exception_hook - -###################################################################### - -if __name__ == "__main__": - import torch - - def dummy(a, b): - print(a @ b) - - def blah(a, b): - c = b + b - dummy(a, c) - - mmm = torch.randn(2, 3) - xxx = torch.randn(3) - # print(xxx@mmm) - blah(mmm, xxx) - blah(xxx, mmm)