Update.
author François Fleuret <francois@fleuret.org>
Mon, 23 Oct 2023 06:16:58 +0000 (08:16 +0200)
committer François Fleuret <francois@fleuret.org>
Mon, 23 Oct 2023 06:16:58 +0000 (08:16 +0200)
main.py
problems.py

diff --git a/main.py b/main.py
index 496a603..f4e4f5c 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -164,6 +164,8 @@ parser.add_argument("--expr_input_file", type=str, default=None)
 
 parser.add_argument("--mixing_hard", action="store_true", default=False)
 
+parser.add_argument("--mixing_deterministic_start", action="store_true", default=False)
+
 ######################################################################
 
 args = parser.parse_args()
@@ -416,7 +418,9 @@ elif args.task == "twotargets":
 
 elif args.task == "mixing":
     task = tasks.SandBox(
-        problem=problems.ProblemMixing(hard=args.mixing_hard),
+        problem=problems.ProblemMixing(
+            hard=args.mixing_hard, random_start=not args.mixing_deterministic_start
+        ),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
         batch_size=args.batch_size,
diff --git a/problems.py b/problems.py
index 9321194..ac16df4 100755 (executable)
--- a/problems.py
+++ b/problems.py
@@ -289,36 +289,38 @@ class ProblemAddition(Problem):
 
 
 class ProblemMixing(Problem):
-    def __init__(self, height=4, width=4, nb_time_steps=9, hard=False):
+    def __init__(
+        self, height=4, width=4, nb_time_steps=9, hard=False, random_start=True
+    ):
         self.height = height
         self.width = width
         self.nb_time_steps = nb_time_steps
         self.hard = hard
+        self.random_start = random_start
 
     def start_random(self, nb):
         y = torch.arange(self.height * self.width).reshape(1, -1).expand(nb, -1)
 
-        # m = (torch.rand(y.size()).sort(dim=-1).indices < y.size(1) // 2).long()
+        if self.random_start:
+            i = (
+                torch.arange(self.height)
+                .reshape(1, -1, 1)
+                .expand(nb, self.height, self.width)
+            )
+            j = (
+                torch.arange(self.width)
+                .reshape(1, 1, -1)
+                .expand(nb, self.height, self.width)
+            )
 
-        i = (
-            torch.arange(self.height)
-            .reshape(1, -1, 1)
-            .expand(nb, self.height, self.width)
-        )
-        j = (
-            torch.arange(self.width)
-            .reshape(1, 1, -1)
-            .expand(nb, self.height, self.width)
-        )
+            ri = torch.randint(self.height, (nb,)).reshape(nb, 1, 1)
+            rj = torch.randint(self.width, (nb,)).reshape(nb, 1, 1)
 
-        ri = torch.randint(self.height, (nb,)).reshape(nb, 1, 1)
-        rj = torch.randint(self.width, (nb,)).reshape(nb, 1, 1)
+            m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
 
-        m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
+            y = y * m + self.height * self.width * (1 - m)
 
-        y = (y * m + self.height * self.width * (1 - m)).reshape(
-            nb, self.height, self.width
-        )
+        y = y.reshape(nb, self.height, self.width)
 
         return y