.item()
)
- self.logger(
- f"back_accuracy {n_epoch=} {model.id=} {nb_correct=} {nb_total=}"
- )
-
n_backward = input[:, 0] == self.token_backward
back_input = self.reverse_time(result[n_backward])
n_backward, 1 : 1 + self.answer_len
]
back_nb_total, back_nb_correct = compute_accuracy(back_input)
+
+ self.logger(
+ f"accuracy {n_epoch=} {model.id=} {nb_correct} / {nb_total}"
+ )
self.logger(
- f"back_accuracy {n_epoch=} {model.id=} {back_nb_correct=} {back_nb_total=}"
+ f"back_accuracy {n_epoch=} {model.id=} {back_nb_correct} / {back_nb_total}"
)
+
nb_total += back_nb_total
nb_correct += back_nb_correct
+ else:
+ self.logger(
+ f"accuracy {n_epoch=} {model.id=} {nb_correct} / {nb_total}"
+ )
else:
nb_total = input.size(0)
X[i1:i2, j1:j2] = c[n]
f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
- def task_move(self, A, f_A, B, f_B):
+ def task_translate(self, A, f_A, B, f_B):
di, dj = torch.randint(3, (2,)) - 1
nb_rec = 3
c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
if n < nb_rec - 1:
f_X[i1, j1] = c[-1]
def task_count(self, A, f_A, B, f_B):
    """Fill one quiz pair with colored cells whose counts must be reported.

    Scatters occurrences of up to three distinct colors at random cells of
    the input grids (A, B); for color ``n`` it writes one marker per
    occurrence along row ``n`` of the answer grids (f_A, f_B), so each
    answer row encodes that color's count.

    NOTE(review): this body was recovered from unresolved diff residue
    ("+ "-prefixed lines) embedded in the source; the logic of the added
    lines is preserved verbatim, only the patch markers were stripped.
    """
    # Number of distinct colors to count: 1..3 (shape-(1,) tensor, used
    # via __index__ below).
    N = torch.randint(3, (1,)) + 1
    # N distinct non-background colors (index 0 is the background color).
    c = torch.randperm(len(self.colors) - 1)[:N] + 1
    for X, f_X in [(A, f_A), (B, f_B)]:
        # Per-color occurrence counts, each in 1..self.width so they fit
        # on one answer row.  assumes self.width <= grid column count and
        # 3 * self.width <= self.height * self.width — TODO confirm.
        nb = torch.randint(self.width, (3,)) + 1
        # Distinct flat cell indices, one per occurrence overall.
        k = torch.randperm(self.height * self.width)[: nb.sum()]
        p = 0
        for n in range(N):
            for m in range(nb[n]):
                # Decode flat index column-major: row = k % height,
                # col = k // height.
                i, j = k[p] % self.height, k[p] // self.height
                X[i, j] = c[n]
                # Row n of the answer gets one cell per occurrence of c[n].
                f_X[n, m] = c[n]
                p += 1
+
######################################################################
def generate_prompts_and_answers(self, nb, device="cpu"):
tasks = [
self.task_replace_color,
- self.task_move,
+ self.task_translate,
self.task_grow,
self.task_color_grow,
self.task_frame,
self.task_detect,
+ self.task_count,
]
prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
prompts[:64],
answers[:64],
# You can add a bool to put a frame around the predicted parts
- predicted_prompts[:64],
- predicted_answers[:64],
+ # predicted_prompts[:64],
+ # predicted_answers[:64],
)