X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=reasoning.py;h=9e26d64e34d1d796187d4f2fb9ad0f055562472c;hb=5a0c2432316b0a413f1769ab429d33433a94e6e1;hp=2874adc9c1c3f213bafce0c8dac53063c6463ab9;hpb=dfa00f17ae94d481ca0f8fd6ce96dbcaa4bbe06e;p=culture.git

diff --git a/reasoning.py b/reasoning.py
index 2874adc..9e26d64 100755
--- a/reasoning.py
+++ b/reasoning.py
@@ -293,7 +293,7 @@ class Reasoning(problem.Problem):
                 X[i1:i2, j1:j2] = c[n]
                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
 
-    def task_move(self, A, f_A, B, f_B):
+    def task_translate(self, A, f_A, B, f_B):
         di, dj = torch.randint(3, (2,)) - 1
         nb_rec = 3
         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
@@ -406,16 +406,90 @@ class Reasoning(problem.Problem):
                 if n < nb_rec - 1:
                     f_X[i1, j1] = c[-1]
 
+    def task_count(self, A, f_A, B, f_B):
+        N = torch.randint(4, (1,)) + 2
+        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+
+        for X, f_X in [(A, f_A), (B, f_B)]:
+
+            def contact(i, j, q):
+                nq, nq_diag = 0, 0
+                no = 0
+
+                for ii, jj in [
+                    (i - 1, j - 1),
+                    (i - 1, j),
+                    (i - 1, j + 1),
+                    (i, j - 1),
+                    (i, j + 1),
+                    (i + 1, j - 1),
+                    (i + 1, j),
+                    (i + 1, j + 1),
+                ]:
+                    if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
+                        if X[ii, jj] != 0 and X[ii, jj] != q:
+                            no += 1
+
+                for ii, jj in [
+                    (i - 1, j - 1),
+                    (i - 1, j + 1),
+                    (i + 1, j - 1),
+                    (i + 1, j + 1),
+                ]:
+                    if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
+                        if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q:
+                            nq_diag += 1
+
+                for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]:
+                    if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
+                        if X[ii, jj] == q:
+                            nq += 1
+
+                return no, nq, nq_diag
+
+            nb = torch.zeros(N, dtype=torch.int64)
+            q = torch.randint(N, (self.height * self.width,))
+            k = torch.randperm(self.height * self.width)
+            for p in range(self.height * self.width):
+                i, j = k[p] % self.height, k[p] // self.height
+                no, nq, nq_diag = contact(i, j, c[q[p]])
+                if no == 0 and nq_diag == 0:
+                    if nq == 0:
+                        if nb[q[p]] < self.width:
+                            X[i, j] = c[q[p]]
+                            nb[q[p]] += 1
+                    if nq == 1:
+                        X[i, j] = c[q[p]]
+
+            for n in range(N):
+                for j in range(nb[n]):
+                    f_X[n, j] = c[n]
+
+    def task_count_(self, A, f_A, B, f_B):
+        N = torch.randint(3, (1,)) + 1
+        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            nb = torch.randint(self.width, (3,)) + 1
+            k = torch.randperm(self.height * self.width)[: nb.sum()]
+            p = 0
+            for n in range(N):
+                for m in range(nb[n]):
+                    i, j = k[p] % self.height, k[p] // self.height
+                    X[i, j] = c[n]
+                    f_X[n, m] = c[n]
+                    p += 1
+
     ######################################################################
 
     def generate_prompts_and_answers(self, nb, device="cpu"):
         tasks = [
             self.task_replace_color,
-            self.task_move,
+            self.task_translate,
             self.task_grow,
             self.task_color_grow,
             self.task_frame,
             self.task_detect,
+            self.task_count,
         ]
         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
         answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
@@ -476,6 +550,6 @@ if __name__ == "__main__":
         prompts[:64],
         answers[:64],
         # You can add a bool to put a frame around the predicted parts
-        predicted_prompts[:64],
-        predicted_answers[:64],
+        # predicted_prompts[:64],
+        # predicted_answers[:64],
     )
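
The core of the new task_count is the contact helper: a cell may receive colour q only if no 8-neighbour carries a different non-zero colour (no == 0), no same-coloured cell touches it across a bare diagonal (nq_diag == 0), and it has at most one same-coloured 4-neighbour, so blobs grow one edge at a time. The snippet below is a minimal, self-contained sketch of that rule, not part of the commit: contact is rewritten at module level, and the 5x5 grid size, the colour indices, and the two probe calls are illustrative assumptions.

import torch

height, width = 5, 5  # assumed grid size for the sketch


def contact(X, i, j, q):
    # Same three counts as the helper inside task_count, written module-level.
    no, nq, nq_diag = 0, 0, 0

    # Foreign colours anywhere in the 8-neighbourhood.
    for ii in range(i - 1, i + 2):
        for jj in range(j - 1, j + 2):
            if (ii, jj) == (i, j) or not (0 <= ii < height and 0 <= jj < width):
                continue
            if X[ii, jj] != 0 and X[ii, jj] != q:
                no += 1

    # Same colour touching only across a diagonal (no shared edge neighbour).
    for ii, jj in [(i - 1, j - 1), (i - 1, j + 1), (i + 1, j - 1), (i + 1, j + 1)]:
        if 0 <= ii < height and 0 <= jj < width:
            if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q:
                nq_diag += 1

    # Same colour sharing an edge.
    for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]:
        if 0 <= ii < height and 0 <= jj < width:
            if X[ii, jj] == q:
                nq += 1

    return no, nq, nq_diag


X = torch.zeros(height, width, dtype=torch.int64)
X[1, 1] = 1
print(contact(X, 1, 2, 1))  # (0, 1, 0): may extend the colour-1 blob by one edge
print(contact(X, 2, 2, 2))  # (1, 0, 0): rejected, a foreign colour is adjacent

With those counts, task_count paints a cell when no == 0 and nq_diag == 0, seeding a new blob when nq == 0 (at most self.width seeds per colour, tracked in nb) and extending an existing blob when nq == 1; f_X then records nb[n] cells of colour c[n] in row n, which is what makes the quiz a counting quiz.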
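
For a quick end-to-end check, the new task can also be driven directly on blank grids. This is a sketch under assumptions not visible in the diff: that Reasoning() can be constructed without arguments and that colour index 0 is the background, as the rest of the file suggests.

import torch
import reasoning  # the patched module from the culture repository

r = reasoning.Reasoning()  # assumed no-arg constructor
A, f_A = torch.zeros(2, r.height, r.width, dtype=torch.int64)
B, f_B = torch.zeros(2, r.height, r.width, dtype=torch.int64)
r.task_count(A, f_A, B, f_B)  # fills the grids in place

# A and B now hold small single-colour blobs; row n of f_A / f_B holds as many
# cells of the n-th drawn colour as blobs of that colour were seeded in A / B.
print(A)
print(f_A)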