X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=4d7e90ec9b8ae6ce5103d564b07abe48e4e27563;hb=a8f039a9b491b1b4b47f6b9f8123c7261e758661;hp=04b8f84577097a7ef3de515fa90f6a461429e7e5;hpb=62e273047aee0a1d606fbe0312abc16a74d23906;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 04b8f84..4d7e90e 100755 --- a/tasks.py +++ b/tasks.py @@ -840,9 +840,8 @@ class Expr(Task): for batch in tqdm.tqdm( input.split(self.batch_size), dynamic_ncols=True, desc=desc ): - if split == "train": - last = (batch != self.filler).max(0).values.nonzero().max() + 3 - batch = batch[:, :last] + last = (batch != self.filler).max(0).values.nonzero().max() + 3 + batch = batch[:, :last] yield batch def vocabulary_size(self): @@ -866,7 +865,8 @@ class Expr(Task): def compute_nb_correct(input): result = input.clone() - ar_mask = (result == self.space).long().cumsum(dim=1).clamp(max=1) + s = (result == self.space).long() + ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1) result = (1 - ar_mask) * result + ar_mask * self.filler masked_inplace_autoregression( model,