X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=82caebe284b6c966f3b1653aa396091edef75e46;hb=e6f60ce506f75da652d069a4926adfa4d92e3cbf;hp=cd0e1ead475e79d1f7ac5aef504d9f8395f99001;hpb=e36adb51bcc003b4189d92d6c5a31a1d86fe8837;p=mygpt.git diff --git a/main.py b/main.py index cd0e1ea..82caebe 100755 --- a/main.py +++ b/main.py @@ -170,6 +170,7 @@ class TaskPicoCLVR(Task): descr = [ s.strip().split(' ') for s in descr ] l = max([ len(s) for s in descr ]) + #descr = [ [ '' ] * (l - len(s)) + s for s in descr ] descr = [ s + [ '' ] * (l - len(s)) for s in descr ] return descr @@ -191,6 +192,7 @@ class TaskPicoCLVR(Task): self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ]) self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ]) + # Tokenize the train and test sets t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ] self.train_input = torch.tensor(t, device = self.device) t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]