From b6053b86cbf22a0eaa27557593f9ba4f9f13900b Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Tue, 20 Aug 2024 22:37:04 +0200
Subject: [PATCH] Update.

---
 grids.py        | 20 +++++------
 main.py         | 67 ++++++++++++++++++++++++++----------
 mygpt.py        | 36 +++++++++----------
 quiz_machine.py | 91 ++++++++++++++++++++++++++-----------------------
 4 files changed, 125 insertions(+), 89 deletions(-)

diff --git a/grids.py b/grids.py
index 0564f3b..b12b4d6 100755
--- a/grids.py
+++ b/grids.py
@@ -167,11 +167,11 @@ class Grids(problem.Problem):
         self.check_structure(quizzes, struct)
         return struct
 
-    def inject_noise(self, quizzes, noise, struct, mask):
+    def inject_noise(self, quizzes, noise, struct, quad):
         assert self.check_structure(quizzes, struct=struct)
         S = self.height * self.width
 
-        mask = torch.tensor(mask, device=quizzes.device)
+        mask = torch.tensor(quad, device=quizzes.device)
         mask = mask[None, :, None].expand(1, 4, S + 1).clone()
         mask[:, :, 0] = 0
         mask = mask.reshape(1, -1).expand_as(quizzes)
@@ -219,20 +219,20 @@ class Grids(problem.Problem):
         ).values
 
     def make_quiz_mask(
-        self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)
+        self, quizzes, struct=("A", "f_A", "B", "f_B"), quad=(0, 0, 0, 1)
     ):
         assert self.check_structure(quizzes, struct)
 
-        ar_mask = quizzes.new_zeros(quizzes.size())
+        mask_ar = quizzes.new_zeros(quizzes.size())
 
         S = self.height * self.width
-        a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:]
-        a[:, 0, :] = mask[0]
-        a[:, 1, :] = mask[1]
-        a[:, 2, :] = mask[2]
-        a[:, 3, :] = mask[3]
+        a = mask_ar.reshape(mask_ar.size(0), 4, S + 1)[:, :, 1:]
+        a[:, 0, :] = quad[0]
+        a[:, 1, :] = quad[1]
+        a[:, 2, :] = quad[2]
+        a[:, 3, :] = quad[3]
 
-        return ar_mask
+        return mask_ar
 
     def indices_select(self, quizzes, struct=("A", "f_A", "B", "f_B")):
         S = self.height * self.width
diff --git a/main.py b/main.py
index 19c8394..8908613 100755
--- a/main.py
+++ b/main.py
@@ -214,6 +214,7 @@ if args.seed >= 0:
     torch.manual_seed(args.seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(args.seed)
+        torch.set_float32_matmul_precision("high")
 
 ######################################################################
 
@@ -326,6 +327,14 @@ def optimizer_to(optim, device):
 ######################################################################
 
 
+def mask_ar_to_ranks(mask_ar):
+    a = (mask_ar < 2).long()
+    a = a.cumsum(dim=1) - a
+    b = ((mask_ar[:, :-1] == 2) & (mask_ar[:, 1:] != 2)).long().cumsum(dim=1)
+    a[:, 1:] += b
+    return a
+
+
 def run_tests(model, quiz_machine, local_device=main_device):
     with torch.autograd.no_grad():
         model.to(local_device).eval()
@@ -335,25 +344,30 @@ def run_tests(model, quiz_machine, local_device=main_device):
         nb_test_samples, acc_test_loss = 0, 0.0
         nb_samples_accumulated = 0
 
-        full_input, full_mask_loss = quiz_machine.data_input(
+        full_input, full_mask_ar, full_mask_loss = quiz_machine.data_input(
             args.nb_test_samples, test_c_quiz_bags
         )
 
         src = zip(
-            full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
+            full_input.split(args.batch_size),
+            full_mask_ar.split(args.batch_size),
+            full_mask_loss.split(args.batch_size),
         )
 
-        for input, mask_loss in tqdm.tqdm(
+        for input, mask_ar, mask_loss in tqdm.tqdm(
            src,
            dynamic_ncols=True,
            desc="test",
            total=full_input.size(0) // args.batch_size,
        ):
            input = input.to(local_device)
+           mask_ar = mask_ar.to(local_device)
            mask_loss = mask_loss.to(local_device)
 
            targets = input
-           output = model(mygpt.BracketedSequence(input)).x
+           output = model(
+               mygpt.BracketedSequence(input, ranks=mask_ar_to_ranks(mask_ar))
+           ).x
            loss_per_token = F.cross_entropy(
                output.transpose(1, 2), targets, reduction="none"
            )
@@ -365,7 +379,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
 
         log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
 
-        input, _ = quiz_machine.data_input(1000, test_c_quiz_bags)
+        input, _, _ = quiz_machine.data_input(1000, test_c_quiz_bags)
 
         model.test_accuracy = quiz_machine.produce_results(
             n_epoch=n_epoch,
@@ -387,18 +401,24 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    full_input, full_mask_loss = quiz_machine.data_input(
+    full_input, full_mask_ar, full_mask_loss = quiz_machine.data_input(
         args.nb_train_samples, train_c_quiz_bags
     )
 
-    src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
-    for input, mask_loss in tqdm.tqdm(
+    src = zip(
+        full_input.split(args.batch_size),
+        full_mask_ar.split(args.batch_size),
+        full_mask_loss.split(args.batch_size),
+    )
+
+    for input, mask_ar, mask_loss in tqdm.tqdm(
         src,
         dynamic_ncols=True,
         desc="training",
         total=full_input.size(0) // args.batch_size,
     ):
         input = input.to(local_device)
+        mask_ar = mask_ar.to(local_device)
         mask_loss = mask_loss.to(local_device)
 
         if nb_train_samples % args.batch_size == 0:
@@ -406,7 +426,9 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
         targets = input
 
-        output = model(mygpt.BracketedSequence(input)).x
+        output = model(
+            mygpt.BracketedSequence(input, ranks=mask_ar_to_ranks(mask_ar))
+        ).x
 
         loss_per_token = F.cross_entropy(
             output.transpose(1, 2), targets, reduction="none"
@@ -456,10 +478,10 @@ def model_modifier_cold(model):
 
 
 c_quizzes_procedure = [
-    # (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot),
     (("f_B", "f_A", "A", "B"), (1, 1, 1, 1), model_modifier_hot),
-    # (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), model_modifier_hot),
-    # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold),
+    # (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot),
+    # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold),
+    # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
 ]
 
 ######################################################################
@@ -580,6 +602,8 @@ def create_c_quizzes(
             procedure=c_quizzes_procedure,
         )
 
+        log_string(f"nb_generated_quizzes {c_quizzes.size(0)}")
+
         nb_generated += c_quizzes.size(0)
 
         # We discard the trivial ones, according to a criterion
@@ -589,6 +613,8 @@ def create_c_quizzes(
 
         c_quizzes = c_quizzes[to_keep]
 
+        log_string(f"nb_non_trivial_quizzes {c_quizzes.size(0)}")
+
         # Keep only the quizzes that the main model cannot solve
 
         solved_c_quizzes = c_quizzes.clone()
@@ -600,11 +626,15 @@ def create_c_quizzes(
             mask=(0, 0, 0, 1),
         )
 
+        log_string(f"nb_generated_quizzes {c_quizzes.size(0)}")
+
         main_probas = model_proba_solutions(main_model, main_solution)
-        log_string(f"main_probas {main_probas}")
+        # log_string(f"main_probas {main_probas}")
         keep = main_probas < args.proba_not_understands
         c_quizzes = c_quizzes[keep]
 
+        log_string(f"nb_not_understood_quizzes {c_quizzes.size(0)}")
+
         # If there are some quizzes that the main model cannot solve,
         # pick the most confident solution
 
@@ -623,14 +653,17 @@ def create_c_quizzes(
                 )
 
                 probas = model_proba_solutions(model, solution)
-                log_string(f"probas {probas}")
+                # log_string(f"probas {probas}")
                 keep = probas >= c_quizzes_proba
                 c_quizzes = solution[keep]
                 c_quizzes_proba[keep] = probas[keep]
 
         keep = c_quizzes_proba >= args.proba_understands
-        recorded.append(c_quizzes_proba[keep])
-        nb_validated += keep.long().sum()
+        c_quizzes = c_quizzes[keep]
+
+        log_string(f"nb_kept {c_quizzes.size(0)} total nb_validated {nb_validated}")
+        recorded.append(c_quizzes.clone().to("cpu"))
+        nb_validated += c_quizzes.size(0)
 
         duration = time.perf_counter() - start_time
 
@@ -683,7 +716,6 @@ for k in range(args.nb_gpts):
         dim_hidden=args.dim_hidden,
         nb_heads=args.nb_heads,
         nb_blocks=args.nb_blocks,
-        compute_attzero=compute_causal_attzero,
        dropout=args.dropout,
    ).to(main_device)
 
@@ -889,7 +921,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
            dim_hidden=args.dim_hidden,
            nb_heads=args.nb_heads,
            nb_blocks=args.nb_blocks,
-            compute_attzero=compute_causal_attzero,
            dropout=args.dropout,
        ).to(main_device)
        model.load_state_dict(new_model.state_dict())
diff --git a/mygpt.py b/mygpt.py
index f716fe5..c69c899 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -76,10 +76,11 @@ class RandomBypass(nn.Module):
 
 
 class BracketedSequence:
-    def __init__(self, x, first=None, nb=None):
+    def __init__(self, x, first=None, nb=None, ranks=None):
         self.x = x
         self.first = 0 if first is None else first
         self.nb = x.size(1) if nb is None else nb
+        self.ranks = ranks
 
     def slice(self):
         return self.x[:, self.first : self.first + self.nb]
@@ -104,7 +105,7 @@ class CacheWrapper(nn.Module):
         else:
             self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
 
-        return BracketedSequence(self.cache_y, bs.first, bs.nb)
+        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.ranks)
 
 
 ##############################
@@ -116,7 +117,7 @@ class WithResidual(nn.Module):
         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
 
     def forward(self, bs):
-        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb)
+        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.ranks)
 
 
 ##############################
@@ -147,7 +148,7 @@ class AddPositionalEncoding(nn.Module):
                 bs.slice() + self.pe[bs.first : bs.first + bs.nb]
             )
 
-        return BracketedSequence(self.cache_y, bs.first, bs.nb)
+        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.ranks)
 
 
 ##############################
@@ -184,7 +185,6 @@ class QKVAttention(nn.Module):
         dim_qk,
         dim_v,
         nb_heads=1,
-        compute_attzero=None,
         attention_dropout=0.0,
     ):
         super().__init__()
@@ -192,7 +192,6 @@ class QKVAttention(nn.Module):
         def randw(*d):
             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
 
-        self.compute_attzero = compute_attzero
         self.attention_dropout = attention_dropout
         self.record_attention = False
 
@@ -234,16 +233,18 @@ class QKVAttention(nn.Module):
             "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_kv.first + bs_kv.nb]
         ) / math.sqrt(self.w_q.size(1))
 
-        if self.compute_attzero is not None:
-            if bs_q.first == 0:
-                self.cache_attzero = self.compute_attzero(
-                    torch.arange(x_q.size(1), device=q.device)[:, None],
-                    torch.arange(x_kv.size(1), device=q.device)[None, :],
-                )[None, None, :, :]
+        t = torch.arange(x_q.size(1), device=a.device)
+
+        if bs_q.ranks is not None:
             a = a.masked_fill(
-                self.cache_attzero[
-                    :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_kv.first + bs_kv.nb
-                ],
+                (
+                    bs_q.ranks[:, None, bs_q.first : bs_q.first + bs_q.nb, None]
+                    <= bs_kv.ranks[:, None, None, : bs_kv.first + bs_kv.nb]
+                )
+                & (
+                    t[None, None, bs_q.first : bs_q.first + bs_q.nb, None]
+                    != t[None, None, None, : bs_kv.first + bs_kv.nb]
+                ),
                 float("-inf"),
             )
 
@@ -297,7 +298,6 @@ class BlockSummarizer(nn.Module):
             dim_qk=dim_keys,
             dim_v=dim_model // nb_heads,
             nb_heads=nb_heads,
-            compute_attzero=compute_attzero,
            attention_dropout=dropout,
        )
 
@@ -310,7 +310,7 @@ class ShiftByOne(nn.Module):
         super().__init__()
 
     def forward(self, bs):
-        return BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+        return BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.ranks)
 
 
 class MyGPT(nn.Module):
@@ -322,7 +322,6 @@ class MyGPT(nn.Module):
         dim_hidden,
         nb_heads,
         nb_blocks,
-        compute_attzero=None,
         dropout=0.0,
         len_max=1e5,
     ):
@@ -354,7 +353,6 @@ class MyGPT(nn.Module):
                         dim_qk=dim_keys,
                         dim_v=dim_model // nb_heads,
                         nb_heads=nb_heads,
-                        compute_attzero=compute_attzero,
                        attention_dropout=dropout,
                    ),
                ),
diff --git a/quiz_machine.py b/quiz_machine.py
index 1acd7ad..d209a07 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -19,7 +19,7 @@ import threading
 
 ######################################################################
 
-# ar_mask is a tensor with 0s and 1s, of same shape as input, with
+# mask_ar is a tensor with 0s and 1s, of same shape as input, with
 # 1s where tokens should be generated. The others are kept
 # unchanged.
 
@@ -27,35 +27,40 @@ import threading
 def one_batch_masked_inplace_autoregression(
     model,
     input,
-    ar_mask,
+    mask_ar,
     acc_seq_logprobas,
     deterministic_synthesis=False,
 ):
     if input.size(0) == 0:
         return
 
-    to_generate = (ar_mask.sum(0) > 0).nonzero()
+    mask = (mask_ar > 0).long()
+    to_generate = (mask.sum(0) > 0).nonzero()
+
+    indices_1 = list(((mask_ar == 1).long().sum(0) > 0).nonzero()) + [mask.size(1)]
 
     if to_generate.min() > 0:
         model(
             BracketedSequence(input, 0, to_generate.min())
         )  # Needed to initialize the model's cache
 
-    for s in range(to_generate.min(), to_generate.max() + 1):
-        output = model(BracketedSequence(input, s, 1)).x
-        logits = output[:, s]
+    s = to_generate.min()
+
+    for s, u in zip(indices_1[:-1], indices_1[1:]):
+        logits = model(BracketedSequence(input, s, u - s)).x
 
         if deterministic_synthesis:
-            t_next = logits.argmax(-1)
+            t_next = logits.argmax(dim=2)
         else:
             dist = torch.distributions.categorical.Categorical(logits=logits)
             t_next = dist.sample()
 
-        all_n = torch.arange(t_next.size(0))
-
-        acc_seq_logprobas += ar_mask[:, s] * logits.log_softmax(dim=1)[all_n, t_next]
+        acc_seq_logprobas += (
+            mask
+            * logits.log_softmax(dim=1).gather(dim=2, index=t_next[:, :, None])[:, :, 0]
+        ).sum(dim=1)
 
-        input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
+        input[...] = mask * t_next + (1 - mask) * input
 
 
 ######################################################################
@@ -81,12 +86,12 @@ class QuizMachine:
         self.answer_len = None
         self.prompt_noise = prompt_noise
 
-        # struct, mask_generate, mask_noise, mask_loss
+        # - struct, quad_generate, quad_noise, quad_loss
 
         self.train_structures = [
-            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
-            (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
-            (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
-            (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
+            (("A", "f_A", "B", "f_B"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
+            (("f_A", "A", "f_B", "B"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
+            (("B", "f_B", "A", "f_A"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
+            (("f_B", "B", "f_A", "A"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
             (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
         ]
 
@@ -101,15 +106,15 @@ class QuizMachine:
         self,
         model,
         input,
-        ar_mask,
+        mask_ar,
         seq_logprobas,
         progress_bar_desc=None,
     ):
-        assert input.size() == ar_mask.size()
+        assert input.size() == mask_ar.size()
 
         batches = zip(
             input.split(self.batch_size),
-            ar_mask.split(self.batch_size),
+            mask_ar.split(self.batch_size),
             seq_logprobas.split(self.batch_size),
         )
 
@@ -125,11 +130,11 @@ class QuizMachine:
             t = model.training
             model.eval()
 
-            for input, ar_mask, seq_logprobas in batches:
+            for input, mask_ar, seq_logprobas in batches:
                 one_batch_masked_inplace_autoregression(
                     model=model,
                     input=input,
-                    ar_mask=ar_mask,
+                    mask_ar=mask_ar,
                     acc_seq_logprobas=seq_logprobas,
                     deterministic_synthesis=False,
                 )
@@ -158,40 +163,44 @@ class QuizMachine:
             quizzes, structs=[s for s, _, _, _ in self.train_structures]
         )
 
+        quiz_mask_ar = quizzes.new_full(quizzes.size(), 1)
         quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
 
-        if self.prompt_noise > 0.0:
-            for struct, _, mask_noise, mask_loss in self.train_structures:
-                i = self.problem.indices_select(quizzes=quizzes, struct=struct)
-                if i.any():
+        for struct, quad_ar, quad_noise, quad_loss in self.train_structures:
+            i = self.problem.indices_select(quizzes=quizzes, struct=struct)
+            if i.any():
+                if self.prompt_noise > 0.0:
                     quizzes[i] = self.problem.inject_noise(
-                        quizzes[i], self.prompt_noise, struct=struct, mask=mask_noise
-                    )
-                    quiz_mask_loss[i] = self.make_quiz_mask(
-                        quizzes=quizzes[i], struct=struct, mask=mask_loss
+                        quizzes[i], self.prompt_noise, struct=struct, quad=quad_noise
                     )
+                quiz_mask_ar[i] = self.make_quiz_mask(
+                    quizzes=quizzes[i], struct=struct, quad=quad_ar
+                )
+                quiz_mask_loss[i] = self.make_quiz_mask(
+                    quizzes=quizzes[i], struct=struct, quad=quad_loss
+                )
 
-        return quizzes, quiz_mask_loss
+        return quizzes, quiz_mask_ar, quiz_mask_loss
 
     ######################################################################
 
-    def make_quiz_mask(self, quizzes, struct, mask):
+    def make_quiz_mask(self, quizzes, struct, quad):
         assert struct in [s for s, _, _, _ in self.train_structures]
-        return self.problem.make_quiz_mask(quizzes, struct=struct, mask=mask)
+        return self.problem.make_quiz_mask(quizzes, struct=struct, quad=quad)
 
     ######################################################################
 
-    def predict(self, model, quizzes, struct, mask):
+    def predict(self, model, quizzes, struct, quad_ar):
         quizzes = quizzes.to(self.device)
-        ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, mask=mask)
-        result = quizzes * (1 - ar_mask)
+        mask_ar = self.make_quiz_mask(quizzes=quizzes, struct=struct, quad=quad_ar)
+        result = quizzes * (mask_ar == 0).long()
 
         seq_logprobas = torch.zeros(quizzes.size(0), device=self.device)
 
         self.autoregression(
             model=model,
             input=result,
-            ar_mask=ar_mask,
+            mask_ar=mask_ar,
             seq_logprobas=seq_logprobas,
             progress_bar_desc="autoregression",
         )
@@ -215,16 +224,14 @@ class QuizMachine:
         nb = 0
 
         # We consider all the configurations that we train for
-        for struct, mask_generate, _, _ in self.test_structures:
+        for struct, quad_ar, _, _ in self.test_structures:
             i = self.problem.indices_select(quizzes=input, struct=struct)
             nb += i.long().sum()
             result[i], correct[i], _ = self.predict(
-                model=model, quizzes=input[i], struct=struct, mask=mask_generate
+                model=model, quizzes=input[i], struct=struct, quad=quad_ar
             )
 
-            predicted_parts[i] = torch.tensor(mask_generate, device=self.device)[
-                None, :
-            ]
+            predicted_parts[i] = torch.tensor(quad_ar, device=self.device)[None, :]
             solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1
             correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long()
 
@@ -351,7 +358,7 @@ class QuizMachine:
         self.autoregression(
             model=model_for_generation,
             input=c_quizzes,
-            ar_mask=self.make_quiz_mask(c_quizzes, s, m),
+            mask_ar=self.make_quiz_mask(c_quizzes, s, m),
             seq_logprobas=seq_logprobas,
             progress_bar_desc=f"autoregression {n_step+1}/{len(procedure)}",
         )
-- 
2.39.5