self.task_trajectory,
self.task_bounce,
self.task_scale,
- # self.task_symbols,
+ self.task_symbols,
self.task_isometry,
- self.task_islands,
+ # self.task_islands,
]
if tasks is None:
while True:
i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
-
+ i[:, 1] += 1
+ j[:, 1] += 1
big_enough = (
(i[:, 1] >= i[:, 0] + min_height)
& (j[:, 1] >= j[:, 0] + min_height)
if d.min() > delta:
break
- for k in range(1, nb_rec):
- X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
-
ai, aj = i.float().mean(), j.float().mean()
q = torch.randint(3, (1,)).item() + 1
- X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
- X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
- X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
- X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
-
assert i[q] != ai and j[q] != aj
+ for Z in [X, f_X]:
+ for k in range(1, nb_rec):
+ Z[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
+ Z[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
+ Z[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
+ Z[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
+ Z[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
+
+ # f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
+ f_X[i[0] + delta // 2, j[0] + delta // 2] = c[q]
+
X[
i[0] + delta // 2 + (i[q] - ai).sign().long(),
j[0] + delta // 2 + (j[q] - aj).sign().long(),
] = c[nb_rec]
- f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
-
# @torch.compile
def task_isometry(self, A, f_A, B, f_B):
nb_rec = 3
f_X[i, s : s + w1] = c1
f_X[i, s + w1 : s + w1 + w2] = c2
+ # @torch.compile
+ # [ai1,ai2] [bi1,bi2]
+ def task_proximity(self, A, f_A, B, f_B):
+ def rec_dist(a, b):
+ ai1, aj1, ai2, aj2 = a
+ bi1, bj1, bi2, bj2 = b
+ v = max(ai1 - bi2, bi1 - ai2)
+ h = max(aj1 - bj2, bj1 - aj2)
+ return min(max(v, 0) + max(h + 1, 0), max(v + 1, 0) + max(h, 0))
+
+ nb_rec = 3
+ c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ while True:
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
+ d = [rec_dist(r[0], r[k]) for k in range(nb_rec)]
+ if min(d[1:]) == 0:
+ break
+
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[n]
+ if d[n] == 0:
+ f_X[i1:i2, j1:j2] = c[0]
+ else:
+ f_X[i1:i2, j1:j2] = c[n]
+
+ # @torch.compile
+ # [ai1,ai2] [bi1,bi2]
+ def task_corners(self, A, f_A, B, f_B):
+ polarity = torch.randint(2, (1,)).item()
+ nb_rec = 3
+ c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
+
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ if polarity == 0:
+ X[i1, j1] = c[n]
+ X[i2 - 1, j2 - 1] = c[n]
+ else:
+ X[i1, j2 - 1] = c[n]
+ X[i2 - 1, j1] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
+
######################################################################
def trivial_prompts_and_answers(self, prompts, answers):
# nb, nrow = 8, 2
# for t in grids.all_tasks:
- for t in [grids.task_convex]:
+ # for t in [grids.task_proximity, grids.task_corners]:
+ for t in [grids.task_symbols]:
print(t.__name__)
prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
# prompts[...] = torch.randint(grids.nb_token_values(), prompts.size())
######################################################################
- def produce_results(self, n_epoch, model, result_dir, deterministic_synthesis):
+ def produce_results(
+ self, n_epoch, model, input, result_dir, deterministic_synthesis
+ ):
def compute_accuracy(input, log_prefix=None):
input = input.to(self.device)
ar_mask = self.problem.make_ar_mask(input, shape="fwd_3_bck_123")
return result, correct
- test_result, test_correct = compute_accuracy(
- model.test_w_quizzes[:2000], log_prefix="test"
- )
+ test_result, test_correct = compute_accuracy(input, log_prefix="test")
n_test_p2a = model.test_w_quizzes[:2000, 0] == self.problem.token_forward