Update.
[picoclvr.git] / tasks.py
index 04b8f84..4d7e90e 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -840,9 +840,8 @@ class Expr(Task):
         for batch in tqdm.tqdm(
             input.split(self.batch_size), dynamic_ncols=True, desc=desc
         ):
-            if split == "train":
-                last = (batch != self.filler).max(0).values.nonzero().max() + 3
-                batch = batch[:, :last]
+            last = (batch != self.filler).max(0).values.nonzero().max() + 3
+            batch = batch[:, :last]
             yield batch
 
     def vocabulary_size(self):
@@ -866,7 +865,8 @@ class Expr(Task):
 
             def compute_nb_correct(input):
                 result = input.clone()
-                ar_mask = (result == self.space).long().cumsum(dim=1).clamp(max=1)
+                s = (result == self.space).long()
+                ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
                 result = (1 - ar_mask) * result + ar_mask * self.filler
                 masked_inplace_autoregression(
                     model,