X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=62a88918e6436fdc68e97f31a6851631ccfb91df;hb=f680fa1486b0a70c37f0951cedd7b5c56b5808bb;hp=3a4a1645ab2fba1408c0263c17e20a8f2e9022b2;hpb=07ce0d849569e234f2d7714d7438dfab29542610;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 3a4a164..62a8891 100755 --- a/tasks.py +++ b/tasks.py @@ -226,7 +226,7 @@ class PicoCLVR(Task): result = self.tensorize(primer) fill = result.new_full( - result.size()[:-1] + (self.height * self.width,), self.t_nul + result.size()[:-1] + (self.height * self.width + 1,), self.t_nul ) result = torch.cat((result, fill), 1) ar_mask = (result == self.t_nul).long()