Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 26 Oct 2024 21:53:32 +0000 (23:53 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 26 Oct 2024 21:53:32 +0000 (23:53 +0200)
grid.py
tinymnist.py

diff --git a/grid.py b/grid.py
index 991c44f..7168491 100755 (executable)
--- a/grid.py
+++ b/grid.py
@@ -4,7 +4,6 @@
 # https://creativecommons.org/publicdomain/zero/1.0/
 
 
-
 # Written by Francois Fleuret <francois@fleuret.org>
 
 # This code implement a simple system to manipulate formal
@@ -72,36 +71,36 @@ class FormalGrid:
             else:
                 return False
 
-        if match("([1-9]) top"):
+        if match("([1-9]) is_in_top_half"):
             (a,) = g[0]
-            return self.row[:, a] < self.grid_height // 4
-        elif match("([1-9]) bottom"):
+            return self.row[:, a] < self.grid_height // 2
+        elif match("([1-9]) is_in_bottom_half"):
             (a,) = g[0]
-            return self.row[:, a] >= (self.grid_height * 3) // 4
-        elif match("([1-9]) left"):
+            return self.row[:, a] >= self.grid_height // 2
+        elif match("([1-9]) is_on_left_side"):
             (a,) = g[0]
-            return self.col[:, a] < self.grid_width // 4
-        elif match("([1-9]) right"):
+            return self.col[:, a] < self.grid_width // 2
+        elif match("([1-9]) is_on_right_side"):
             (a,) = g[0]
-            return self.col[:, a] >= (self.grid_width * 3) // 4
+            return self.col[:, a] >= self.grid_width // 2
         elif match("([1-9]) next_to ([1-9])"):
             a, b = g[0]
             return (self.row[:, a] - self.row[:, b]).abs() + (
                 self.col[:, a] - self.col[:, b]
             ).abs() <= 1
-        elif match("([1-9]) below_of ([1-9])"):
+        elif match("([1-9]) is_below ([1-9])"):
             a, b = g[0]
             return self.row[:, a] > self.row[:, b]
-        elif match("([1-9]) above ([1-9])"):
+        elif match("([1-9]) is_above ([1-9])"):
             a, b = g[0]
             return self.row[:, a] < self.row[:, b]
-        elif match("([1-9]) left_of ([1-9])"):
+        elif match("([1-9]) is_left_of ([1-9])"):
             a, b = g[0]
             return self.col[:, a] < self.col[:, b]
-        elif match("([1-9]) right_of ([1-9])"):
+        elif match("([1-9]) is_right_of ([1-9])"):
             a, b = g[0]
             return self.col[:, a] > self.col[:, b]
-        elif match("([1-9]) ([1-9]) diagonal"):
+        elif match("([1-9]) ([1-9]) parallel_to_diagonal"):
             a, b = g[0]
             return (self.col[:, a] - self.col[:, b]).abs() == (
                 self.row[:, a] - self.row[:, b]
@@ -113,7 +112,7 @@ class FormalGrid:
             a, b = g[0]
             return self.row[:, a] == self.row[:, b]
 
-        elif match("([1-9]) ([1-9]) ([1-9]) aligned"):
+        elif match("([1-9]) ([1-9]) ([1-9]) are_aligned"):
             a, b, c = g[0]
             return (self.col[:, a] - self.col[:, b]) * (
                 self.row[:, a] - self.row[:, c]
@@ -129,7 +128,15 @@ class FormalGrid:
                 & (self.row[:, a] + self.row[:, c] == 2 * self.row[:, b])
             )
 
-        elif match("([1-9]) further_away_from ([1-9]) than ([1-9])"):
+        elif match("([1-9]) is_equidistant_from ([1-9]) and ([1-9])"):
+            a, b, c = g[0]
+            return (self.col[:, a] - self.col[:, b]) ** 2 + (
+                self.row[:, a] - self.row[:, b]
+            ) ** 2 == (self.col[:, a] - self.col[:, c]) ** 2 + (
+                self.row[:, a] - self.row[:, c]
+            ) ** 2
+
+        elif match("([1-9]) is_further_away_from ([1-9]) than ([1-9])"):
             a, b, c = g[0]
             return (self.col[:, a] - self.col[:, b]) ** 2 + (
                 self.row[:, a] - self.row[:, b]
@@ -137,7 +144,7 @@ class FormalGrid:
                 self.row[:, a] - self.row[:, c]
             ) ** 2
 
-        elif match("([1-9]) ([1-9]) ([1-9]) right_angle"):
+        elif match("([1-9]) ([1-9]) ([1-9]) make_right_angle"):
             a, b, c = g[0]
             return (self.col[:, a] - self.col[:, b]) * (
                 self.col[:, c] - self.col[:, b]
@@ -185,25 +192,22 @@ class FormalGrid:
 
 ######################################################################
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+if __name__ == "__main__":
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-grid = FormalGrid(device=device)
+    grid = FormalGrid(grid_height=8, grid_width=8, nb_symbols=4, device=device)
 
-grid_set = grid.new_grid_set(
-    [
-        "4 top",
-        "4 right",
-        "1 top",
-        "1 left",
-        "1 left_of 2",
-        "2 left_of 3",
-        "1 2 4 right_angle",
-        "1 2 3 aligned",
-        "2 further_away_from 3 than 4",
-    ],
-)
+    grid_set = grid.new_grid_set(
+        [
+            "1 2 3 make_right_angle",
+            "2 3 4 make_right_angle",
+            "3 4 1 make_right_angle",
+            "2 is_equidistant_from 1 and 3",
+            "1 is_above 4",
+        ],
+    )
 
-print(f"There are {grid_set.long().sum().item()} configurations")
+    print(f"There are {grid_set.long().sum().item()} configurations")
 
-for v in grid.views(grid_set):
-    print(v)
+    for v in grid.views(grid_set):
+        print(v)
index 896477e..f662be6 100755 (executable)
@@ -70,14 +70,14 @@ test_input.sub_(mu).div_(std)
 start_time = time.perf_counter()
 
 for k in range(nb_epochs):
-    acc_loss = 0.0
+    acc_train_loss = 0.0
 
     for input, targets in zip(
         train_input.split(batch_size), train_targets.split(batch_size)
     ):
         output = model(input)
         loss = criterion(output, targets)
-        acc_loss += loss.item()
+        acc_train_loss += loss.item() * input.size(0)
 
         optimizer.zero_grad()
         loss.backward()
@@ -92,6 +92,8 @@ for k in range(nb_epochs):
     test_error = nb_test_errors / test_input.size(0)
     duration = time.perf_counter() - start_time
 
-    print(f"loss {k} {duration:.02f}s {acc_loss:.02f} {test_error*100:.02f}%")
+    print(
+        f"loss {k} {duration:.02f}s {acc_train_loss/train_input.size(0):.02f} {test_error*100:.02f}%"
+    )
 
 ######################################################################