From 866a55a388f0c6e8a5be1fb2bf46db442e378d9b Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 19 Jul 2024 15:55:45 +0200 Subject: [PATCH] Update. --- grids.py | 109 +++++++++++++++++++++++++++++++++++++++--------- main.py | 7 +++- quiz_machine.py | 71 +++++++------------------------ 3 files changed, 110 insertions(+), 77 deletions(-) diff --git a/grids.py b/grids.py index c2ff0d1..e1eff00 100755 --- a/grids.py +++ b/grids.py @@ -118,6 +118,61 @@ class Grids(problem.Problem): ("gray", [128, 128, 128]), ] + def make_ar_mask(self, quizzes, first=False): + S = self.height * self.width + + assert ( + ( + (quizzes[:, 0] == self.token_forward) + | (quizzes[:, 0] == self.token_backward) + ) + & (quizzes[:, 0] == quizzes[:, 1 * (S + 1)]) + & (quizzes[:, 0] == quizzes[:, 2 * (S + 1)]) + & (quizzes[:, 0] == quizzes[:, 3 * (S + 1)]) + ).all() + + T = torch.arange(quizzes.size(1), device=quizzes.device) + + if first: + forward_mask = ((T % (S + 1) != 0) & (T < 3 * (S + 1))).long() + backward_mask = ((T % (S + 1) != 0) & (T < S + 1)).long() + else: + forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long() + backward_mask = ((T % (S + 1) != 0) & (T >= S + 1)).long() + + is_forward = (quizzes[:, 0] == self.token_forward).long() + + return ( + is_forward[:, None] * forward_mask[None, :] + + (1 - is_forward)[:, None] * backward_mask[None, :] + ) + + def p_a_flip(self, quizzes): + S = self.height * self.width + + assert ( + ( + (quizzes[:, 0] == self.token_forward) + | (quizzes[:, 0] == self.token_backward) + ) + & (quizzes[:, 0] == quizzes[:, 1 * (S + 1)]) + & (quizzes[:, 0] == quizzes[:, 2 * (S + 1)]) + & (quizzes[:, 0] == quizzes[:, 3 * (S + 1)]) + ).all() + + flipped = torch.cat( + [quizzes[:, k * (S + 1) : (k + 1) * (S + 1)] for k in range(3, -1, -1)], + dim=1, + ) + + m = (flipped[:, 0] == self.token_forward).long() + flipped[:, 0 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward + flipped[:, 1 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward + flipped[:, 2 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward + flipped[:, 3 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward + + return flipped + def __init__( self, max_nb_cached_chunks=None, @@ -158,29 +213,40 @@ class Grids(problem.Problem): ###################################################################### def frame2img(self, x, scale=15): - x = x.reshape(x.size(0), self.height, -1) + x = x.reshape(x.size(0), self.height, self.width) m = torch.logical_and(x >= 0, x < len(self.colors)).long() - x = self.colors[x * m].permute(0, 3, 1, 2) - s = x.shape - x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) - x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale) + y = self.colors[x * m].permute(0, 3, 1, 2) + s = y.shape + y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) + y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale) - x[:, :, :, torch.arange(0, x.size(3), scale)] = 0 - x[:, :, torch.arange(0, x.size(2), scale), :] = 0 - x = x[:, :, 1:, 1:] + y[:, :, :, torch.arange(0, y.size(3), scale)] = 0 + y[:, :, torch.arange(0, y.size(2), scale), :] = 0 + y = y[:, :, 1:, 1:] for n in range(m.size(0)): for i in range(m.size(1)): for j in range(m.size(2)): - if m[n, i, j] == 0: + if x[n, i, j] == self.token_forward: + for k in range(2, scale - 2): + y[ + n, + :, + i * scale + k, + j * scale + scale - 5 - abs(k - scale // 2), + ] = 0 + + elif x[n, i, j] == self.token_backward: for k in range(2, scale - 2): - for l in [0, 1]: - x[n, :, i * scale + k, j * scale + k - l] = 0 - x[ - n, :, i * scale + scale - 1 - k, j * scale + k - l - ] = 0 + y[ + n, :, i * scale + k, j * scale + 3 + abs(k - scale // 2) + ] = 0 + # y[n, :, i * scale + k, j * scale + k - l] = 0 + # y[ + # n, :, i * scale + scale - 1 - k, j * scale + k - l + # ] = 0 - return x + return y def save_image( self, @@ -1223,6 +1289,10 @@ class Grids(problem.Problem): ) f_B = answer[1 : S + 1].view(self.height, self.width) task = tasks[torch.randint(len(tasks), (1,)).item()] + A[...] = 0 + f_A[...] = 0 + B[...] = 0 + f_B[...] = 0 task(A, f_A, B, f_B) return prompts.flatten(1), answers.flatten(1) @@ -1277,23 +1347,24 @@ if __name__ == "__main__": # exit(0) # if True: - nb, nrow = 128, 4 + nb, nrow = 8, 2 # nb, nrow = 8, 2 for t in grids.all_tasks: # for t in [grids.task_compute]: print(t.__name__) prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) + # prompts[...] = torch.randint(grids.nb_token_values(), prompts.size()) grids.save_quiz_illustrations( "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow ) - # exit(0) + exit(0) nb = 1000 - # for t in grids.all_tasks: - for t in [grids.task_compute]: + for t in grids.all_tasks: + # for t in [grids.task_compute]: start_time = time.perf_counter() prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) delay = time.perf_counter() - start_time diff --git a/main.py b/main.py index d9257db..0182e6a 100755 --- a/main.py +++ b/main.py @@ -3,6 +3,9 @@ # Any copyright is dedicated to the Public Domain. # https://creativecommons.org/publicdomain/zero/1.0/ +# > A > f(A) > B ; > f(B) +# < f(B) ; < B < f(A) < A + # Written by Francois Fleuret import math, sys, argparse, time, tqdm, os, datetime, warnings @@ -496,11 +499,11 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 v_train = validated_quizzes[:nb_for_train] quiz_machine.store_c_quizzes(v_train, for_train=True) - quiz_machine.store_c_quizzes(quiz_machine.p_a_flip(v_train), for_train=True) + quiz_machine.store_c_quizzes(quiz_machine.problem.p_a_flip(v_train), for_train=True) v_test = validated_quizzes[nb_for_train:nb_to_create] quiz_machine.store_c_quizzes(v_test, for_train=False) - quiz_machine.store_c_quizzes(quiz_machine.p_a_flip(v_test), for_train=False) + quiz_machine.store_c_quizzes(quiz_machine.problem.p_a_flip(v_test), for_train=False) ###################################################################### # save images diff --git a/quiz_machine.py b/quiz_machine.py index cc81086..046ab73 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -154,56 +154,17 @@ class QuizMachine: n_p2a = quizzes[quizzes[:, 0] == self.problem.token_forward] n_a2p = quizzes[:, 0] == self.problem.token_backward a2p = quizzes[n_a2p] - quizzes[n_a2p] = self.p_a_flip(quizzes[n_a2p]) + quizzes[n_a2p] = self.problem.p_a_flip(quizzes[n_a2p]) return torch.logical_not( self.problem.trivial_prompts_and_answers( quizzes[:, : self.prompt_len], quizzes[:, self.prompt_len :] ) ) - def p_a_flip(self, quizzes): - i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes) - - p2a_to_a2p = torch.cat( - [quizzes[:, self.prompt_len :], quizzes[:, : self.prompt_len]], - dim=1, - ) - - p2a_to_a2p[:, 0] = self.problem.token_backward - p2a_to_a2p[:, self.answer_len] = self.problem.token_backward - - a2p_to_p2a = torch.cat( - [quizzes[:, self.answer_len :], quizzes[:, : self.answer_len]], - dim=1, - ) - - a2p_to_p2a[:, 0] = self.problem.token_forward - a2p_to_p2a[:, self.prompt_len] = self.problem.token_forward - - m = i_p2a.long()[:, None] - - return m * p2a_to_a2p + (1 - m) * a2p_to_p2a - def p_a_flip_half_in_place(self, quizzes): i = torch.rand(quizzes.size(0)) < 0.5 if i.any(): - quizzes[i] = self.p_a_flip(quizzes[i]) - - def make_ar_mask(self, quizzes, first=False): - i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes) - - t = torch.arange(quizzes.size(1), device=quizzes.device) - - if first: - m_p2a = (t >= 1).long() * (t < self.prompt_len).long() - m_a2p = (t >= 1).long() * (t < self.answer_len).long() - else: - m_p2a = (t >= 1 + self.prompt_len).long() - m_a2p = (t >= 1 + self.answer_len).long() - - m = i_p2a.long()[:, None] - - return m * m_p2a + (1 - m) * m_a2p + quizzes[i] = self.problem.p_a_flip(quizzes[i]) def generate_token_sequences(self, nb): prompts, answers = self.problem.generate_prompts_and_answers(nb) @@ -261,7 +222,7 @@ class QuizMachine: n_a2p = quizzes[:, 0] == self.problem.token_backward a2p = quizzes[n_a2p] assert n_p2a.size(0) + a2p.size(0) == quizzes.size(0) - quizzes[n_a2p] = self.p_a_flip(quizzes[n_a2p]) + quizzes[n_a2p] = self.problem.p_a_flip(quizzes[n_a2p]) if show_part_to_predict: predicted_prompts = n_a2p.long() @@ -332,7 +293,7 @@ class QuizMachine: def produce_results(self, n_epoch, model, result_dir, deterministic_synthesis): def compute_accuracy(input, log_prefix=None): input = input.to(self.device) - ar_mask = self.make_ar_mask(input) + ar_mask = self.problem.make_ar_mask(input) result = input.clone() * (1 - ar_mask) seq_logproba = torch.empty(input.size(0), device=self.device) @@ -357,7 +318,7 @@ class QuizMachine: if self.back_accuracy and n_a2p.any(): # accuracy of B->A*->B*=B instead of B->A*=A - back_input = self.p_a_flip(result[n_a2p]) + back_input = self.problem.p_a_flip(result[n_a2p]) back_input[:, 1 + self.prompt_len :] = input[n_a2p, 1 : self.answer_len] _, correct[n_a2p] = compute_accuracy(back_input) @@ -471,7 +432,7 @@ class QuizMachine: c_quizzes.split(self.batch_size), logproba.split(self.batch_size) ): input = input.to(self.device) - ar_mask = self.make_ar_mask(input) + ar_mask = self.problem.make_ar_mask(input) output = model(mygpt.BracketedSequence(input)).x l[:, model.id] = ( -F.cross_entropy( @@ -506,7 +467,7 @@ class QuizMachine: c_quizzes = c_quizzes.to(self.device) result = c_quizzes.clone() - ar_mask = self.make_ar_mask(result) + ar_mask = self.problem.make_ar_mask(result) masked_inplace_autoregression( model=model, @@ -545,14 +506,13 @@ class QuizMachine: seq_logproba = torch.zeros(nb, device=self.device) if p2a_only: - c_quizzes[:, 0] = self.problem.token_forward - c_quizzes[:, self.prompt_len] = self.problem.token_forward + c_quizzes[...] = self.problem.token_forward masked_inplace_autoregression( model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=self.make_ar_mask(c_quizzes, first=True), + ar_mask=self.problem.make_ar_mask(c_quizzes, first=True), seq_logproba=seq_logproba, temperature=temperature_hot, deterministic_synthesis=False, @@ -563,7 +523,7 @@ class QuizMachine: model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=self.make_ar_mask(c_quizzes), + ar_mask=self.problem.make_ar_mask(c_quizzes), seq_logproba=seq_logproba, temperature=temperature_cold, deterministic_synthesis=False, @@ -571,14 +531,13 @@ class QuizMachine: ) else: - c_quizzes[:, 0] = self.problem.token_backward - c_quizzes[:, self.answer_len] = self.problem.token_backward + c_quizzes[...] = self.problem.token_backward masked_inplace_autoregression( model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=self.make_ar_mask(c_quizzes, first=True), + ar_mask=self.problem.make_ar_mask(c_quizzes, first=True), seq_logproba=seq_logproba, temperature=temperature_hot, deterministic_synthesis=False, @@ -589,20 +548,20 @@ class QuizMachine: model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=self.make_ar_mask(c_quizzes), + ar_mask=self.problem.make_ar_mask(c_quizzes), seq_logproba=seq_logproba, temperature=temperature_cold, deterministic_synthesis=False, device=self.device, ) - c_quizzes = self.p_a_flip(c_quizzes) + c_quizzes = self.problem.p_a_flip(c_quizzes) masked_inplace_autoregression( model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=self.make_ar_mask(c_quizzes), + ar_mask=self.problem.make_ar_mask(c_quizzes), seq_logproba=seq_logproba, temperature=temperature_cold, deterministic_synthesis=False, -- 2.20.1