X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=62a88918e6436fdc68e97f31a6851631ccfb91df;hb=363ce48d64d1a036b86d29564bf6ad367126c2b1;hp=affc8cdcfae81d694e46c1e7afdcb2cc2d0656b4;hpb=5aee50805cfad1dd49bbf30b30fe65b05e03de78;p=picoclvr.git diff --git a/tasks.py b/tasks.py index affc8cd..62a8891 100755 --- a/tasks.py +++ b/tasks.py @@ -225,6 +225,10 @@ class PicoCLVR(Task): primer += [primer_descr + " "] * nb_per_primer result = self.tensorize(primer) + fill = result.new_full( + result.size()[:-1] + (self.height * self.width + 1,), self.t_nul + ) + result = torch.cat((result, fill), 1) ar_mask = (result == self.t_nul).long() masked_inplace_autoregression( model,