From: François Fleuret Date: Thu, 6 Jul 2023 09:31:37 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=046c2b8633a415e533ec14cb72d77845f0c3e85f;p=picoclvr.git Update. --- diff --git a/main.py b/main.py index 5b49468..df3f154 100755 --- a/main.py +++ b/main.py @@ -14,7 +14,7 @@ import torch, torchvision from torch import nn from torch.nn import functional as F -import mygpt, tasks, tensorstack +import mygpt, tasks ###################################################################### @@ -384,16 +384,16 @@ train_set_perplexity = math.exp(entropy) train_examples = {} for input in task.batches(split="train"): - assert input.dim()==2 and input.dtype==torch.int64 + assert input.dim() == 2 and input.dtype == torch.int64 for x in input: - train_examples[x.sum().item()]=x + train_examples[x.sum().item()] = x for input in task.batches(split="test"): - assert input.dim()==2 and input.dtype==torch.int64 + assert input.dim() == 2 and input.dtype == torch.int64 for x in input: y = train_examples.get(x.sum().item()) if y is not None: - assert x.size() != y.size() or (x-y).abs().sum() > 0 + assert x.size() != y.size() or (x - y).abs().sum() > 0 del train_examples diff --git a/tensorstack.py b/tensorstack.py deleted file mode 100755 index 584c12d..0000000 --- a/tensorstack.py +++ /dev/null @@ -1,61 +0,0 @@ -#!/usr/bin/env python - -# Any copyright is dedicated to the Public Domain. -# https://creativecommons.org/publicdomain/zero/1.0/ - -# Written by Francois Fleuret - -from torch import Tensor - -import sys - - -def exception_hook(exc_type, exc_value, tb): - r"""Hacks the call stack message to show all the local variables in - case of RuntimeError or ValueError, and prints tensors as shape, - dtype and device. - - """ - - repr_orig = Tensor.__repr__ - Tensor.__repr__ = lambda x: f"{x.size()}:{x.dtype}:{x.device}" - - while tb: - print("--------------------------------------------------\n") - filename = tb.tb_frame.f_code.co_filename - name = tb.tb_frame.f_code.co_name - line_no = tb.tb_lineno - print(f' File "{filename}", line {line_no}, in {name}') - print(open(filename, "r").readlines()[line_no - 1]) - - if exc_type in {RuntimeError, ValueError}: - for n, v in tb.tb_frame.f_locals.items(): - print(f" {n} -> {v}") - - print() - tb = tb.tb_next - - Tensor.__repr__ = repr_orig - - print(f"{exc_type.__name__}: {exc_value}") - - -sys.excepthook = exception_hook - -###################################################################### - -if __name__ == "__main__": - import torch - - def dummy(a, b): - print(a @ b) - - def blah(a, b): - c = b + b - dummy(a, c) - - mmm = torch.randn(2, 3) - xxx = torch.randn(3) - # print(xxx@mmm) - blah(mmm, xxx) - blah(xxx, mmm)