X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=grids.py;h=7aec62c16dd26f88ea9ea9d4938e099646f4f1ba;hb=12c775dcbd3d3cd703f35c181faa6d2a680a0450;hp=d1653ee5e627b67b6ce3209e864fb263174b1917;hpb=a795fa74d5a75318b1196bef048086cd64c41397;p=culture.git diff --git a/grids.py b/grids.py index d1653ee..7aec62c 100755 --- a/grids.py +++ b/grids.py @@ -869,6 +869,74 @@ class Grids(problem.Problem): # i,j=q%self.height,q//self.height # if + # @torch.compile + def task_puzzle(self, A, f_A, B, f_B): + S = 4 + i0, j0 = (self.height - S) // 2, (self.width - S) // 2 + c = torch.randperm(len(self.colors) - 1)[:4] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + f_X[...] = 0 + h = list(torch.randperm(c.size(0))) + n = torch.zeros(c.max() + 1) + for _ in range(2): + k = torch.randperm(S * S) + for q in k: + i, j = q % S + i0, q // S + j0 + if f_X[i, j] == 0: + r, s, t, u = ( + f_X[i - 1, j], + f_X[i, j - 1], + f_X[i + 1, j], + f_X[i, j + 1], + ) + r, s, t, u = torch.tensor([r, s, t, u])[torch.randperm(4)] + if r > 0 and n[r] < 6: + n[r] += 1 + f_X[i, j] = r + elif s > 0 and n[s] < 6: + n[s] += 1 + f_X[i, j] = s + elif t > 0 and n[t] < 6: + n[t] += 1 + f_X[i, j] = t + elif u > 0 and n[u] < 6: + n[u] += 1 + f_X[i, j] = u + else: + if len(h) > 0: + d = c[h.pop()] + n[d] += 1 + f_X[i, j] = d + + if n.sum() == S * S: + break + + k = 0 + for d in range(4): + while True: + ii, jj = torch.randint(self.height, (1,)), torch.randint( + self.width, (1,) + ) + e = 0 + for i in range(S): + for j in range(S): + if ( + ii + i >= self.height + or jj + j >= self.width + or ( + f_X[i + i0, j + j0] == c[d] + and X[ii + i, jj + j] > 0 + ) + ): + e = 1 + if e == 0: + break + for i in range(S): + for j in range(S): + if f_X[i + i0, j + j0] == c[d]: + X[ii + i, jj + j] = c[d] + ###################################################################### def all_tasks(self): @@ -942,6 +1010,15 @@ class Grids(problem.Problem): nrow, ) + def save_some_examples(self, result_dir): + nb, nrow = 72, 4 + for t in self.all_tasks(): + print(t.__name__) + prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t]) + self.save_quizzes( + result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow + ) + ###################################################################### @@ -967,12 +1044,12 @@ if __name__ == "__main__": # nb, nrow = 8, 2 # for t in grids.all_tasks(): - for t in [grids.task_path]: + for t in [grids.task_puzzle]: print(t.__name__) prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow) - # exit(0) + exit(0) nb = 1000