x = x.flatten(1)
u = torch.arange(self.height * self.width).reshape(1, -1)
m = ((x - u).abs() == 0).long()
- d = (x - (m * u + (1-m) * self.height * self.width)).abs().sum(-1) + (
+ d = (x - (m * u + (1 - m) * self.height * self.width)).abs().sum(-1) + (
m.sum(dim=-1) != self.height * self.width // 2
).long()
return d
####################
if __name__ == "__main__":
- p = ProblemMixing(width=4, hard=True)
+ p = ProblemMixing()
s, m = p.generate_sequences(10000)
for x in s[:5]:
print(p.seq2str(x))