Update.
[picoclvr.git] / main.py
diff --git a/main.py b/main.py
index 14b1bc3..314a961 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -45,9 +45,9 @@ parser.add_argument("--nb_epochs", type=int, default=None)
 
 parser.add_argument("--batch_size", type=int, default=None)
 
-parser.add_argument("--nb_train_samples", type=int, default=250000)
+parser.add_argument("--nb_train_samples", type=int, default=None)
 
-parser.add_argument("--nb_test_samples", type=int, default=10000)
+parser.add_argument("--nb_test_samples", type=int, default=None)
 
 parser.add_argument("--optim", type=str, default="adam")
 
@@ -113,7 +113,7 @@ parser.add_argument("--stack_nb_steps", type=int, default=100)
 
 parser.add_argument("--stack_nb_stacks", type=int, default=1)
 
-parser.add_argument("--stack_nb_values", type=int, default=10)
+parser.add_argument("--stack_nb_digits", type=int, default=1)
 
 ######################################################################
 
@@ -214,7 +214,7 @@ def masked_inplace_autoregression(
     # entropy[:,s]= p.xlogy(p).sum(1) / math.log(2)
     batches = zip(input.split(batch_size), ar_mask.split(batch_size))
     if progress_bar_desc is not None:
-        tqdm.tqdm(
+        batches = tqdm.tqdm(
             batches,
             dynamic_ncols=True,
             desc=progress_bar_desc,
@@ -875,28 +875,28 @@ class TaskStack(Task):
         batch_size,
         nb_steps,
         nb_stacks,
-        nb_values,
+        nb_digits,
         device=torch.device("cpu"),
     ):
         self.batch_size = batch_size
         self.nb_steps = nb_steps
         self.nb_stacks = nb_stacks
-        self.nb_values = nb_values
+        self.nb_digits = nb_digits
         self.device = device
 
         self.train_input, self.train_stack_counts = stack.generate_sequences(
-            nb_train_samples, nb_steps, nb_stacks, nb_values, self.device
+            nb_train_samples, nb_steps, nb_stacks, nb_digits, self.device
         )
 
         self.test_input, self.test_stack_counts = stack.generate_sequences(
-            nb_test_samples, nb_steps, nb_stacks, nb_values, self.device
+            nb_test_samples, nb_steps, nb_stacks, nb_digits, self.device
         )
 
         mask = self.test_input.clone()
-        stack.remove_poped_values(mask,self.nb_stacks)
-        mask=(mask!=self.test_input)
+        stack.remove_popped_values(mask, self.nb_stacks, self.nb_digits)
+        mask = mask != self.test_input
         counts = self.test_stack_counts.flatten()[mask.flatten()]
-        counts=F.one_hot(counts).sum(0)
+        counts = F.one_hot(counts).sum(0)
         log_string(f"stack_count {counts}")
 
         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
@@ -923,19 +923,19 @@ class TaskStack(Task):
 
             def compute_nb_correct(input):
                 result = input.clone()
-                stack.remove_poped_values(result,self.nb_stacks)
+                stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
                 ar_mask = (result != input).long()
-                result *= 1 - ar_mask
-
                 masked_inplace_autoregression(
                     model, self.batch_size, result, ar_mask, device=self.device
                 )
 
-                nb_total = ar_mask.sum()
+                errors = ((result != input).long() * ar_mask).reshape(
+                    -1, 1 + self.nb_digits
+                )
+                ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
 
-                nb_correct = (
-                    (result == input).long() * ar_mask
-                ).sum()
+                nb_total = ar_mask.max(1).values.sum()
+                nb_correct = nb_total - errors.max(1).values.sum()
 
                 return nb_total, nb_correct
 
@@ -945,6 +945,24 @@ class TaskStack(Task):
                 f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
             )
 
+            #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+            input = self.test_input[:10, :20]
+            result = input.clone()
+            stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
+            ar_mask = (result != input).long()
+            for n in range(result.size(0)):
+                log_string(
+                    f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
+                )
+            masked_inplace_autoregression(
+                model, self.batch_size, result, ar_mask, device=self.device
+            )
+            for n in range(result.size(0)):
+                log_string(
+                    f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
+                )
+            #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
             model.train(t)
 
 
@@ -1017,9 +1035,9 @@ elif args.task == "stack":
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
         batch_size=args.batch_size,
-        nb_steps = args.stack_nb_steps,
-        nb_stacks = args.stack_nb_stacks,
-        nb_values = args.stack_nb_values,
+        nb_steps=args.stack_nb_steps,
+        nb_stacks=args.stack_nb_stacks,
+        nb_digits=args.stack_nb_digits,
         device=device,
     )