Update.
[picoclvr.git] / tasks.py
index 463d94c..4d7e90e 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -809,9 +809,8 @@ class Expr(Task):
             nb_train_samples,
             nb_variables=nb_variables,
             length=sequence_length,
-            # length=2 * sequence_length,
-            # randomize_length=True,
         )
+
         test_sequences = expr.generate_sequences(
             nb_test_samples,
             nb_variables=nb_variables,
@@ -841,9 +840,8 @@ class Expr(Task):
         for batch in tqdm.tqdm(
             input.split(self.batch_size), dynamic_ncols=True, desc=desc
         ):
-            if split == "train":
-                last = (batch != self.filler).max(0).values.nonzero().max() + 3
-                batch = batch[:, :last]
+            last = (batch != self.filler).max(0).values.nonzero().max() + 3
+            batch = batch[:, :last]
             yield batch
 
     def vocabulary_size(self):
@@ -867,7 +865,8 @@ class Expr(Task):
 
             def compute_nb_correct(input):
                 result = input.clone()
-                ar_mask = (result == self.space).long().cumsum(dim=1).clamp(max=1)
+                s = (result == self.space).long()
+                ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
                 result = (1 - ar_mask) * result + ar_mask * self.filler
                 masked_inplace_autoregression(
                     model,
@@ -911,7 +910,7 @@ class Expr(Task):
                 test_nb_correct,
                 test_nb_delta,
                 test_nb_missed,
-            ) = compute_nb_correct(self.test_input[:1000])
+            ) = compute_nb_correct(self.test_input[:10000])
 
             logger(
                 f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"