problem,
nb_train_samples,
nb_test_samples,
+ back_accuracy,
batch_size,
result_dir,
logger,
self.nb_token_values = v + 2
self.problem = problem
+ self.back_accuracy = back_accuracy
self.batch_size = batch_size
self.device = device
self.logger = logger
self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
):
def compute_accuracy(input):
- input = input[:nmax]
ar_mask = self.make_ar_mask(input)
result = input.clone() * (1 - ar_mask)
seq_logproba = torch.empty(input.size(0), device=self.device)
device=self.device,
)
- nb_total = input.size(0)
- nb_correct = (input == result).long().min(dim=1).values.sum()
+ if self.back_accuracy:
+ n_forward = input[:, 0] == self.token_forward
+ nb_total = input[n_forward].size(0)
+ nb_correct = (
+ (input[n_forward] == result[n_forward])
+ .long()
+ .min(dim=1)
+ .values.sum()
+ )
+
+ n_backward = input[:, 0] == self.token_backward
+ back_input = self.reverse_time(result[n_backward])
+ if back_input.size(0) > 0:
+ back_input[:, 2 + self.prompt_len :] = input[
+ n_backward, 2 + self.prompt_len :
+ ]
+ back_nb_total, back_nb_correct = compute_accuracy(back_input)
+ nb_total += back_nb_total
+ nb_correct += back_nb_correct
+ else:
+ nb_total = input.size(0)
+ nb_correct = (input == result).long().min(dim=1).values.sum()
return nb_total, nb_correct
- train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes)
+ train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes[:nmax])
self.logger(
f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
)
- test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes)
+ test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes[:nmax])
self.logger(
f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
######################################################################
def frame2img(self, x, scale=15):
    """Render token grids as images and cross out cells holding invalid tokens.

    Each frame of ``x`` is reshaped into ``self.height`` rows of cells. A
    cell whose token lies in ``[0, self.nb_token_values())`` is painted with
    ``self.colors[token]``. Any other cell is painted with ``self.colors[0]``
    and then over-drawn with a two-pixel-wide diagonal cross of zeros.

    Args:
        x: integer tensor whose first dimension indexes the frames; the
            elements of each frame are split into ``self.height`` rows.
        scale: side, in pixels, of one cell before the final one-pixel crop.

    Returns:
        A tensor of shape ``(N, C, rows * scale - 1, cols * scale - 1)``
        where ``C`` is the trailing dimension of ``self.colors``.
    """
    x = x.reshape(x.size(0), self.height, -1)

    # 1 where the token is a valid index into self.colors, 0 elsewhere.
    # Multiplying by the mask sends invalid tokens to index 0 so the colour
    # lookup below cannot go out of range.
    m = torch.logical_and(x >= 0, x < self.nb_token_values()).long()
    x = self.colors[x * m].permute(0, 3, 1, 2)

    # Blow every cell up to a scale x scale square of pixels.
    s = x.shape
    x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
    x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)

    # Zero the first pixel column/row of every cell to draw the grid, then
    # drop the outermost top and left grid lines.
    x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
    x[:, :, torch.arange(0, x.size(2), scale), :] = 0
    x = x[:, :, 1:, 1:]

    # Visit only the invalid cells instead of testing every cell one tensor
    # element at a time.
    for n, i, j in (m == 0).nonzero().tolist():
        top, left = i * scale, j * scale
        for k in range(2, scale - 2):
            for shift in (0, 1):
                col = left + k - shift
                x[n, :, top + k, col] = 0  # main diagonal
                x[n, :, top + scale - 1 - k, col] = 0  # anti-diagonal

    return x
+
+ def frame2img_(self, x, scale=15):
x = x.reshape(x.size(0), self.height, -1)
x = self.colors[x].permute(0, 3, 1, 2)
s = x.shape
# non-overlapping rectangles quickly, but made the generation of
# 100k samples go from 1h50 with a lame pure python code to 3min30s
# with this one.
- def rec_coo(self, x, n, min_height=3, min_width=3):
- K = 3
- N = 200
+ def rec_coo(self, nb_rec, min_height=3, min_width=3):
+ nb_trials = 200
while True:
v = (
(
- torch.rand(N * K, self.height + 1, device=self.device)
+ torch.rand(nb_trials * nb_rec, self.height + 1, device=self.device)
.sort(dim=-1)
.indices
< 2
h = (
(
- torch.rand(N * K, self.width + 1, device=self.device)
+ torch.rand(nb_trials * nb_rec, self.width + 1, device=self.device)
.sort(dim=-1)
.indices
< 2
)
v, h = v[i], h[i]
- v = v[: v.size(0) - v.size(0) % K]
- h = h[: h.size(0) - h.size(0) % K]
- v = v.reshape(v.size(0) // K, K, -1)
- h = h.reshape(h.size(0) // K, K, -1)
+ v = v[: v.size(0) - v.size(0) % nb_rec]
+ h = h[: h.size(0) - h.size(0) % nb_rec]
+ v = v.reshape(v.size(0) // nb_rec, nb_rec, -1)
+ h = h.reshape(h.size(0) // nb_rec, nb_rec, -1)
r = v[:, :, :, None] * h[:, :, None, :]
######################################################################
def task_replace_color(self, A, f_A, B, f_B):
    """Fill two grid pairs where the target recolours the first rectangle.

    For each pair ``(X, f_X)`` in ``(A, f_A)`` and ``(B, f_B)``, three
    rectangles are requested from ``self.rec_coo`` and painted into ``X``
    with three distinct colours. ``f_X`` receives the same rectangles with
    the same colours, except the first one, which is painted with a fourth
    colour. The colours are drawn once and shared by both pairs.

    Args:
        A, B: 2-D grids written in place with the input pictures.
        f_A, f_B: 2-D grids written in place with the target pictures.
    """
    nb_rec = 3
    # nb_rec + 1 distinct colour indices in [1, len(self.colors) - 1];
    # index 0 is never drawn.
    c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
    for X, f_X in [(A, f_A), (B, f_B)]:
        r = self.rec_coo(nb_rec)
        for n in range(nb_rec):
            # Each entry unpacks to four values used as half-open slice bounds.
            i1, j1, i2, j2 = r[n]
            X[i1:i2, j1:j2] = c[n]
            # The first rectangle takes the extra (last) colour in the target.
            f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
def task_move(self, A, f_A, B, f_B):
di, dj = torch.randint(2, (2,)) * 2 - 1
- N = 3
- c = torch.randperm(len(self.colors) - 1)[:N] + 1
+ nb_rec = 3
+ c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
- r = self.rec_coo(X, N)
- i1, j1, i2, j2 = r[N - 1]
+ r = self.rec_coo(nb_rec)
+ i1, j1, i2, j2 = r[nb_rec - 1]
if (
i1 + di >= 0
and i2 + di < X.size(0)
):
break
- for n in range(N):
+ for n in range(nb_rec):
i1, j1, i2, j2 = r[n]
X[i1:i2, j1:j2] = c[n]
- if n == N - 1:
+ if n == nb_rec - 1:
f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
else:
f_X[i1:i2, j1:j2] = c[n]
def task_grow(self, A, f_A, B, f_B):
di, dj = torch.randint(2, (2,)) * 2 - 1
- N = 3
- c = torch.randperm(len(self.colors) - 1)[:N] + 1
+ nb_rec = 3
+ c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
direction = torch.randint(2, (1,))
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
- r = self.rec_coo(X, N)
- i1, j1, i2, j2 = r[N - 1]
+ r = self.rec_coo(nb_rec)
+ i1, j1, i2, j2 = r[nb_rec - 1]
if i1 + 3 < i2 and j1 + 3 < j2:
break
- for n in range(N):
+ for n in range(nb_rec):
i1, j1, i2, j2 = r[n]
- if n == N - 1:
+ if n == nb_rec - 1:
if direction == 0:
X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
f_X[i1:i2, j1:j2] = c[n]
def task_color_grow(self, A, f_A, B, f_B):
    """Fill two grid pairs where the last rectangle's stripe grows to an edge.

    For each pair ``(X, f_X)`` in ``(A, f_A)`` and ``(B, f_B)``, three
    rectangles are requested from ``self.rec_coo``. Every rectangle is
    filled with one colour and crossed by a one-cell-wide stripe of a second
    colour through its middle. In ``f_X`` the stripe of the last rectangle
    is extended to one side of the rectangle; the other rectangles are
    identical in ``X`` and ``f_X``. Colours and direction are drawn once and
    shared by both pairs.

    Direction codes:
        0: horizontal stripe, grown down to the bottom edge.
        1: horizontal stripe, grown up to the top edge.
        2: vertical stripe, grown right to the right edge.
        3: vertical stripe, grown left to the left edge.

    Args:
        A, B: 2-D grids written in place with the input pictures.
        f_A, f_B: 2-D grids written in place with the target pictures.
    """
    # Two random values are drawn and discarded: nothing below reads them,
    # but the draw is kept so the random stream consumed by this task does
    # not change.
    torch.randint(2, (2,))
    nb_rec = 3
    # Two distinct colour indices per rectangle, all in
    # [1, len(self.colors) - 1].
    c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
    direction = int(torch.randint(4, (1,)))
    for X, f_X in [(A, f_A), (B, f_B)]:
        r = self.rec_coo(nb_rec)
        for n in range(nb_rec):
            # Each entry unpacks to four values used as half-open slice bounds.
            i1, j1, i2, j2 = r[n]
            fill, stripe_color = c[2 * n], c[2 * n + 1]
            X[i1:i2, j1:j2] = fill
            f_X[i1:i2, j1:j2] = fill

            # `stripe` is the one-cell-wide line drawn in X; `grown` is the
            # band it becomes in the target of the last rectangle.
            if direction == 0:
                i = (i1 + i2) // 2
                stripe = (slice(i, i + 1), slice(j1, j2))
                grown = (slice(i, i2), slice(j1, j2))
            elif direction == 1:
                i = (i1 + i2 - 1) // 2
                stripe = (slice(i, i + 1), slice(j1, j2))
                grown = (slice(i1, i + 1), slice(j1, j2))
            elif direction == 2:
                j = (j1 + j2) // 2
                stripe = (slice(i1, i2), slice(j, j + 1))
                grown = (slice(i1, i2), slice(j, j2))
            else:  # direction == 3, the only remaining value of randint(4)
                j = (j1 + j2 - 1) // 2
                stripe = (slice(i1, i2), slice(j, j + 1))
                grown = (slice(i1, i2), slice(j1, j + 1))

            X[stripe] = stripe_color
            f_X[grown if n == nb_rec - 1 else stripe] = stripe_color
def task_frame(self, A, f_A, B, f_B):
    """Fill two grid pairs where the target hollows out the last rectangle.

    For each pair ``(X, f_X)`` in ``(A, f_A)`` and ``(B, f_B)``, three
    rectangles are requested from ``self.rec_coo`` and painted with three
    distinct colours into both ``X`` and ``f_X``. In ``f_X`` the interior of
    the last rectangle (everything but its one-cell border) is then reset
    to 0, leaving a frame. The colours are drawn once and shared by both
    pairs.

    Args:
        A, B: 2-D grids written in place with the input pictures.
        f_A, f_B: 2-D grids written in place with the target pictures.
    """
    nb_rec = 3
    # Distinct colour indices in [1, len(self.colors) - 1]. One more than
    # needed is drawn; only the first nb_rec are used below.
    c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
    for X, f_X in [(A, f_A), (B, f_B)]:
        r = self.rec_coo(nb_rec)
        for n in range(nb_rec):
            # Each entry unpacks to four values used as half-open slice bounds.
            i1, j1, i2, j2 = r[n]
            X[i1:i2, j1:j2] = c[n]
            f_X[i1:i2, j1:j2] = c[n]
            if n == nb_rec - 1:
                # Keep only the border of the last rectangle in the target.
                f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
def task_detect(self, A, f_A, B, f_B):
    """Fill two grid pairs where the target marks corners of some rectangles.

    For each pair ``(X, f_X)`` in ``(A, f_A)`` and ``(B, f_B)``, three
    rectangles are requested from ``self.rec_coo`` and painted into ``X``
    with three distinct colours. ``f_X`` is not painted with the rectangles;
    it only receives a single cell of a fourth colour at the ``(i1, j1)``
    corner of every rectangle except the last one. The colours are drawn
    once and shared by both pairs.

    Args:
        A, B: 2-D grids written in place with the input pictures.
        f_A, f_B: 2-D grids written in place with the target pictures; cells
            other than the marked corners are left as the caller provided
            them.
    """
    nb_rec = 3
    # nb_rec rectangle colours plus one marker colour, all distinct and in
    # [1, len(self.colors) - 1].
    c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
    for X, f_X in [(A, f_A), (B, f_B)]:
        r = self.rec_coo(nb_rec)
        for n in range(nb_rec):
            # Each entry unpacks to four values used as half-open slice bounds.
            i1, j1, i2, j2 = r[n]
            X[i1:i2, j1:j2] = c[n]
            # The last rectangle is deliberately left unmarked in the target.
            if n < nb_rec - 1:
                f_X[i1, j1] = c[-1]
######################################################################
reasoning.save_quizzes(
"/tmp",
"test",
- prompts[:36],
- answers[:36],
+ prompts[:64],
+ answers[:64],
# You can add a bool to put a frame around the predicted parts
# predicted_prompts, predicted_answers
)