X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=reasoning.py;h=c545e97f0b6517d7bbec6d163a55b447c2e0a989;hb=cb64cb72d5fff242cd90c13e4d8b9ca43e6e8837;hp=9e26d64e34d1d796187d4f2fb9ad0f055562472c;hpb=5a0c2432316b0a413f1769ab429d33433a94e6e1;p=culture.git diff --git a/reasoning.py b/reasoning.py index 9e26d64..c545e97 100755 --- a/reasoning.py +++ b/reasoning.py @@ -465,19 +465,95 @@ class Reasoning(problem.Problem): for j in range(nb[n]): f_X[n, j] = c[n] - def task_count_(self, A, f_A, B, f_B): - N = torch.randint(3, (1,)) + 1 - c = torch.randperm(len(self.colors) - 1)[:N] + 1 + def task_trajectory(self, A, f_A, B, f_B): + c = torch.randperm(len(self.colors) - 1)[:2] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - nb = torch.randint(self.width, (3,)) + 1 - k = torch.randperm(self.height * self.width)[: nb.sum()] - p = 0 - for n in range(N): - for m in range(nb[n]): - i, j = k[p] % self.height, k[p] // self.height - X[i, j] = c[n] - f_X[n, m] = c[n] - p += 1 + while True: + di, dj = torch.randint(7, (2,)) - 3 + i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,)) + if ( + abs(di) + abs(dj) > 0 + and i + 2 * di >= 0 + and i + 2 * di < self.height + and j + 2 * dj >= 0 + and j + 2 * dj < self.width + ): + break + + k = 0 + while ( + i + k * di >= 0 + and i + k * di < self.height + and j + k * dj >= 0 + and j + k * dj < self.width + ): + if k < 2: + X[i + k * di, j + k * dj] = c[k] + f_X[i + k * di, j + k * dj] = c[min(k, 1)] + k += 1 + + def task_bounce(self, A, f_A, B, f_B): + c = torch.randperm(len(self.colors) - 1)[:3] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + + def free(i, j): + return ( + i >= 0 + and i < self.height + and j >= 0 + and j < self.width + and f_X[i, j] == 0 + ) + + while True: + f_X[...] = 0 + X[...] = 0 + + for _ in range((self.height * self.width) // 10): + i, j = torch.randint(self.height, (1,)), torch.randint( + self.width, (1,) + ) + X[i, j] = c[0] + f_X[i, j] = c[0] + + while True: + di, dj = torch.randint(7, (2,)) - 3 + if abs(di) + abs(dj) == 1: + break + + i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,)) + + X[i, j] = c[1] + f_X[i, j] = c[1] + l = 0 + + while True: + l += 1 + if free(i + di, j + dj): + pass + elif free(i - dj, j + di): + di, dj = -dj, di + if free(i + dj, j - di): + if torch.rand(1) < 0.5: + di, dj = -di, -dj + elif free(i + dj, j - di): + di, dj = dj, -di + else: + break + + i, j = i + di, j + dj + f_X[i, j] = c[2] + if l <= 1: + X[i, j] = c[2] + + if l >= self.width: + break + + f_X[i, j] = c[1] + X[i, j] = c[1] + + if l > 3: + break ###################################################################### @@ -490,6 +566,8 @@ class Reasoning(problem.Problem): self.task_frame, self.task_detect, self.task_count, + self.task_trajectory, + self.task_bounce, ] prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)