From: Francois Fleuret Date: Thu, 28 Jul 2022 06:50:21 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=6260a3593ac09a9bbdd9c85b23d78a71fa028acd;p=mygpt.git Update. --- diff --git a/main.py b/main.py index 1b011a2..427a83a 100755 --- a/main.py +++ b/main.py @@ -156,13 +156,14 @@ import picoclvr class TaskPicoCLVR(Task): + # Make a tensor from a list of strings def tensorize(self, descr): - descr = [ s.strip().split(' ') for s in descr ] - l = max([ len(s) for s in descr ]) - #descr = [ [ '' ] * (l - len(s)) + s for s in descr ] - descr = [ s + [ '' ] * (l - len(s)) for s in descr ] - t = [ [ self.token2id[u] for u in s ] for s in descr ] - return torch.tensor(t, device = self.device) + token_descr = [ s.strip().split(' ') for s in descr ] + l = max([ len(s) for s in token_descr ]) + #token_descr = [ [ '' ] * (l - len(s)) + s for s in token_descr ] + token_descr = [ s + [ '' ] * (l - len(s)) for s in token_descr ] + id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ] + return torch.tensor(id_descr, device = self.device) def __init__(self, batch_size, height, width, nb_colors = 5, @@ -281,6 +282,7 @@ class TaskWiki103(Task): self.vocab.set_default_index(self.vocab[ '' ]) + # makes a tensor from a list of list of tokens def tensorize(self, s): a = max(len(x) for x in s) return torch.tensor([ self.vocab(x + [ '' ] * (a - len(x))) for x in s ])