Update.
[picoclvr.git] / tasks.py
index 912b405..b277b96 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -748,18 +748,21 @@ class Stack(Task):
             result = input.clone()
             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
             ar_mask = (result != input).long()
-            for n in range(result.size(0)):
-                logger(
-                    f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
-                )
-                masked_inplace_autoregression(
-                    model,
-                    self.batch_size,
-                    result,
-                    ar_mask,
-                    deterministic_synthesis,
-                    device=self.device,
-                )
+
+            # for n in range(result.size(0)):
+            # logger(
+            # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
+            # )
+
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                device=self.device,
+            )
+
             for n in range(result.size(0)):
                 logger(
                     f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
@@ -936,21 +939,24 @@ class Expr(Task):
             result = input.clone()
             ar_mask = (result == self.space).long().cumsum(dim=1).clamp(max=1)
             result = (1 - ar_mask) * result + ar_mask * self.filler
-            for n in range(result.size(0)):
-                logger(f"test_before {self.seq2str(result[n])}")
-                masked_inplace_autoregression(
-                    model,
-                    self.batch_size,
-                    result,
-                    ar_mask,
-                    deterministic_synthesis,
-                    device=self.device,
-                )
+
+            # for n in range(result.size(0)):
+            # logger(f"test_before {self.seq2str(result[n])}")
+
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                device=self.device,
+            )
+
             correct = (1 - ar_mask) * self.space + ar_mask * input
             for n in range(result.size(0)):
                 comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
                 logger(f"test_after  {self.seq2str(result[n])} {comment}")
-                logger(f"correct     {self.seq2str(correct[n])}")
+                logger(f"truth       {self.seq2str(correct[n])}")
             ##############################################################
 
             model.train(t)