X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=grids.py;h=659bd6c088d362a6dde33f5538e49b8c12dd77f3;hb=57a13bdaf395838f93dcd67dce3151e2ed9eb3f1;hp=247c146d16140ed24cc636d1baf74dc21aeefbeb;hpb=c9c9df43a6b97e5b3e81c8cf05d2f1b3010dea05;p=culture.git diff --git a/grids.py b/grids.py index 247c146..659bd6c 100755 --- a/grids.py +++ b/grids.py @@ -579,38 +579,52 @@ class Grids(problem.Problem): X[i, j] = c[1] f_X[0:2, 0:2] = c[1] - def task_islands(self, A, f_A, B, f_B): + def task_symbols(self, A, f_A, B, f_B): + nb_rec = 4 + c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 + delta = 3 for X, f_X in [(A, f_A), (B, f_B)]: while True: - i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,)) - if ( - i == 0 - or i == self.height - 1 - or j == 0 - or j == self.width - 1 - or X[i, j] == 1 - ): - break - while True: - di, dj = torch.randint(3, (2,)) - 1 - if abs(di) + abs(dj) > 0: - break - X[i, j] = 1 - while True: - i, j = i + di, j + dj - if i < 0 or i >= self.height or j < 0 or j >= self.width: - break - b = ( - i == 0 - or i == self.height - 1 - or j == 0 - or j == self.width - 1 - or X[i, j] == 1 + i, j = torch.randint(self.height - delta + 1, (nb_rec,)), torch.randint( + self.width - delta + 1, (nb_rec,) ) - X[i, j] = 1 - if b: + d = (i[None, :] - i[:, None]).abs().max((j[None, :] - j[:, None]).abs()) + d.fill_diagonal_(delta + 1) + if d.min() > delta: break + for k in range(1, nb_rec): + X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k] + + ai, aj = i.float().mean(), j.float().mean() + + q = torch.randint(3, (1,)) + 1 + + X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0] + X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0] + X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0] + X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0] + + assert i[q] != ai and j[q] != aj + + X[ + i[0] + delta // 2 + (i[q] - ai).sign().long(), + j[0] + delta // 2 + (j[q] - aj).sign().long(), + ] = c[nb_rec] + + f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q] + + def task_islands(self, A, f_A, B, f_B): + pass + + # for X, f_X in [(A, f_A), (B, f_B)]: + # n = torch.arange(self.height * self.width).reshape(self.height, self.width) + # k = torch.randperm(self.height * self.width) + # X[...]=-1 + # for q in k: + # i,j=q%self.height,q//self.height + # if + ###################################################################### def all_tasks(self): @@ -625,9 +639,16 @@ class Grids(problem.Problem): self.task_trajectory, self.task_bounce, self.task_scale, + self.task_symbols, # self.task_islands, ] + def trivial_prompts_and_answers(self, prompts, answers): + S = self.height * self.width + Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S] + f_Bs = answers + return (B_s == f_Bs).long().min(dim=-1).values > 0 + def generate_prompts_and_answers(self, nb, tasks=None, device="cpu"): if tasks is None: tasks = self.all_tasks() @@ -681,8 +702,8 @@ if __name__ == "__main__": grids = Grids() - for t in grids.all_tasks(): - # for t in [grids.task_islands]: + # for t in grids.all_tasks(): + for t in [grids.task_islands]: print(t.__name__) prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t]) grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4)