import math, re
-import torch, torchvision
+import torch
-from torch import nn
from torch.nn import functional as F
######################################################################
nb += F.one_hot(
(self.row[:, s] * self.grid_width + self.col[:, s]).long(),
num_classes=self.grid_height * self.grid_width,
- ).to(torch.uint8)
+ ).to(torch.int8)
self.master_grid_set = nb.max(dim=1).values <= 1
def new_grid_set(self, constraints=None):
elif match("([1-9]) is_right_of ([1-9])"):
return self.col[:, a] > self.col[:, b]
- elif match("([1-9]) ([1-9]) parallel_to_diagonal"):
+ elif match("([1-9]) ([1-9]) is_parallel_to_diagonal"):
return (self.col[:, a] - self.col[:, b]).abs() == (
self.row[:, a] - self.row[:, b]
).abs()
- elif match("([1-9]) ([1-9]) vertical"):
+ elif match("([1-9]) ([1-9]) is_vertical"):
return self.col[:, a] == self.col[:, b]
- elif match("([1-9]) ([1-9]) horizontal"):
+ elif match("([1-9]) ([1-9]) is_horizontal"):
return self.row[:, a] == self.row[:, b]
elif match("([1-9]) ([1-9]) ([1-9]) are_aligned"):
grid = FormalGrid(grid_height=8, grid_width=8, nb_symbols=4, device=device)
+ # grid_set = grid.new_grid_set(
+ # [
+ # "1 2 3 form_a_right_angle",
+ # "2 3 4 form_a_right_angle",
+ # "3 4 1 form_a_right_angle",
+ # "2 is_equidistant_from 1 and 3",
+ # "1 is_above 4",
+ # ],
+ # )
+
grid_set = grid.new_grid_set(
[
- "1 2 3 form_a_right_angle",
- "2 3 4 form_a_right_angle",
- "3 4 1 form_a_right_angle",
- "2 is_equidistant_from 1 and 3",
- "1 is_above 4",
+ "1 2 3 are_aligned",
+ "2 3 is_parallel_to_diagonal",
+ "4 1 is_vertical",
+ "3 4 is_horizontal",
],
)