class ProblemMixing(Problem):
- def __init__(self, height=3, width=3, nb_time_steps=12, hard=False):
+ def __init__(self, height=4, width=4, nb_time_steps=9, hard=False):
self.height = height
self.width = width
self.nb_time_steps = nb_time_steps
def start_random(self, nb):
y = torch.arange(self.height * self.width).reshape(1, -1).expand(nb, -1)
- m = (torch.rand(y.size()).sort(dim=-1).indices < y.size(1) // 2).long()
+ # m = (torch.rand(y.size()).sort(dim=-1).indices < y.size(1) // 2).long()
+
+ i = torch.arange(self.height).reshape(1,-1,1).expand(nb,self.height,self.width)
+ j = torch.arange(self.width).reshape(1,1,-1).expand(nb,self.height,self.width)
+
+ ri = torch.randint(self.height, (nb,)).reshape(nb,1,1)
+ rj = torch.randint(self.width, (nb,)).reshape(nb,1,1)
+
+ m = 1 - torch.logical_or(i==ri,j==rj).long().flatten(1)
y = (y * m + self.height * self.width * (1 - m)).reshape(
nb, self.height, self.width
return y
def start_error(self, x):
+ i = torch.arange(self.height).reshape(1,-1,1).expand_as(x)
+ j = torch.arange(self.width).reshape(1,1,-1).expand_as(x)
+
+ ri = (x == self.height * self.width).long().sum(dim=-1).argmax(-1).view(-1,1,1)
+ rj = (x == self.height * self.width).long().sum(dim=-2).argmax(-1).view(-1,1,1)
+
+ m = 1 - torch.logical_or(i==ri,j==rj).long().flatten(1)
+
x = x.flatten(1)
- u = torch.arange(self.height * self.width).reshape(1, -1)
- m = ((x - u).abs() == 0).long()
- d = (x - (m * u + (1 - m) * self.height * self.width)).abs().sum(-1) + (
- m.sum(dim=-1) != self.height * self.width // 2
- ).long()
+ u = torch.arange(self.height * self.width, device = x.device).reshape(1, -1)
+
+ d = (x - (m * u + (1 - m) * self.height * self.width)).abs().sum(-1)
return d
def moves(self, x):
return " | ".join(
[
" ".join(
- ["-".join([f"{x:02d}" for x in s]) for s in r.split(self.width)]
+ ["-".join([f"{x:02d}" if x < self.height * self.width else "**" for x in s]) for s in r.split(self.width)]
)
for r in seq.split(self.height * self.width)
]