X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=463d94ca1ae4dc43a054c875eba2fe686f0686bd;hb=233f57347c9560aec2f3cbaf001a8efa56a0243b;hp=b277b96683a211cf07ea4a642a8f22c8ff6e1972;hpb=9c974541bec4b1aa67bb39bfff4ef128c36ae3d9;p=picoclvr.git diff --git a/tasks.py b/tasks.py index b277b96..463d94c 100755 --- a/tasks.py +++ b/tasks.py @@ -937,11 +937,12 @@ class Expr(Task): input = self.tensorize(sequences) result = input.clone() - ar_mask = (result == self.space).long().cumsum(dim=1).clamp(max=1) + s = (result == self.space).long() + ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1) result = (1 - ar_mask) * result + ar_mask * self.filler - # for n in range(result.size(0)): - # logger(f"test_before {self.seq2str(result[n])}") + for n in range(result.size(0)): + logger(f"test_before {self.seq2str(result[n])}") masked_inplace_autoregression( model,