From 96429a6891d09c994a13a0b6969c7f82c45945a7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 29 Oct 2024 08:21:47 +0100 Subject: [PATCH] Update. --- grid.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/grid.py b/grid.py index 12ecba1..d06802d 100755 --- a/grid.py +++ b/grid.py @@ -11,9 +11,8 @@ import math, re -import torch, torchvision +import torch -from torch import nn from torch.nn import functional as F ###################################################################### @@ -49,7 +48,7 @@ class FormalGrid: nb += F.one_hot( (self.row[:, s] * self.grid_width + self.col[:, s]).long(), num_classes=self.grid_height * self.grid_width, - ).to(torch.uint8) + ).to(torch.int8) self.master_grid_set = nb.max(dim=1).values <= 1 def new_grid_set(self, constraints=None): @@ -102,15 +101,15 @@ class FormalGrid: elif match("([1-9]) is_right_of ([1-9])"): return self.col[:, a] > self.col[:, b] - elif match("([1-9]) ([1-9]) parallel_to_diagonal"): + elif match("([1-9]) ([1-9]) is_parallel_to_diagonal"): return (self.col[:, a] - self.col[:, b]).abs() == ( self.row[:, a] - self.row[:, b] ).abs() - elif match("([1-9]) ([1-9]) vertical"): + elif match("([1-9]) ([1-9]) is_vertical"): return self.col[:, a] == self.col[:, b] - elif match("([1-9]) ([1-9]) horizontal"): + elif match("([1-9]) ([1-9]) is_horizontal"): return self.row[:, a] == self.row[:, b] elif match("([1-9]) ([1-9]) ([1-9]) are_aligned"): @@ -193,13 +192,22 @@ if __name__ == "__main__": grid = FormalGrid(grid_height=8, grid_width=8, nb_symbols=4, device=device) + # grid_set = grid.new_grid_set( + # [ + # "1 2 3 form_a_right_angle", + # "2 3 4 form_a_right_angle", + # "3 4 1 form_a_right_angle", + # "2 is_equidistant_from 1 and 3", + # "1 is_above 4", + # ], + # ) + grid_set = grid.new_grid_set( [ - "1 2 3 form_a_right_angle", - "2 3 4 form_a_right_angle", - "3 4 1 form_a_right_angle", - "2 is_equidistant_from 1 and 3", - "1 is_above 4", + "1 2 3 are_aligned", + "2 3 is_parallel_to_diagonal", + "4 1 is_vertical", + "3 4 is_horizontal", ], ) -- 2.39.5