Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 29 Oct 2024 07:21:47 +0000 (08:21 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 29 Oct 2024 07:21:47 +0000 (08:21 +0100)
grid.py

diff --git a/grid.py b/grid.py
index 12ecba1..d06802d 100755 (executable)
--- a/grid.py
+++ b/grid.py
@@ -11,9 +11,8 @@
 
 import math, re
 
-import torch, torchvision
+import torch
 
-from torch import nn
 from torch.nn import functional as F
 
 ######################################################################
@@ -49,7 +48,7 @@ class FormalGrid:
             nb += F.one_hot(
                 (self.row[:, s] * self.grid_width + self.col[:, s]).long(),
                 num_classes=self.grid_height * self.grid_width,
-            ).to(torch.uint8)
+            ).to(torch.int8)
         self.master_grid_set = nb.max(dim=1).values <= 1
 
     def new_grid_set(self, constraints=None):
@@ -102,15 +101,15 @@ class FormalGrid:
         elif match("([1-9]) is_right_of ([1-9])"):
             return self.col[:, a] > self.col[:, b]
 
-        elif match("([1-9]) ([1-9]) parallel_to_diagonal"):
+        elif match("([1-9]) ([1-9]) is_parallel_to_diagonal"):
             return (self.col[:, a] - self.col[:, b]).abs() == (
                 self.row[:, a] - self.row[:, b]
             ).abs()
 
-        elif match("([1-9]) ([1-9]) vertical"):
+        elif match("([1-9]) ([1-9]) is_vertical"):
             return self.col[:, a] == self.col[:, b]
 
-        elif match("([1-9]) ([1-9]) horizontal"):
+        elif match("([1-9]) ([1-9]) is_horizontal"):
             return self.row[:, a] == self.row[:, b]
 
         elif match("([1-9]) ([1-9]) ([1-9]) are_aligned"):
@@ -193,13 +192,22 @@ if __name__ == "__main__":
 
     grid = FormalGrid(grid_height=8, grid_width=8, nb_symbols=4, device=device)
 
+    # grid_set = grid.new_grid_set(
+    # [
+    # "1 2 3 form_a_right_angle",
+    # "2 3 4 form_a_right_angle",
+    # "3 4 1 form_a_right_angle",
+    # "2 is_equidistant_from 1 and 3",
+    # "1 is_above 4",
+    # ],
+    # )
+
     grid_set = grid.new_grid_set(
         [
-            "1 2 3 form_a_right_angle",
-            "2 3 4 form_a_right_angle",
-            "3 4 1 form_a_right_angle",
-            "2 is_equidistant_from 1 and 3",
-            "1 is_above 4",
+            "1 2 3 are_aligned",
+            "2 3 is_parallel_to_diagonal",
+            "4 1 is_vertical",
+            "3 4 is_horizontal",
         ],
     )