From 6abefdb1be24f0f5d5ed034ee0ee259032c6ce0a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 30 Jul 2024 23:09:44 +0200 Subject: [PATCH] Update. --- main.py | 377 ++++++++++++++++++++++++++++++----------------------- problem.py | 20 ++- 2 files changed, 226 insertions(+), 171 deletions(-) diff --git a/main.py b/main.py index 7aeae98..50e34a8 100755 --- a/main.py +++ b/main.py @@ -675,6 +675,166 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 ###################################################################### +def generate_c_quizz_with_generator(generator, quiz_machine, nb): + generator.to(main_device) + + c_quizzes = quiz_machine.problem.create_empty_quizzes( + nb, struct=("A", "f_A", "B", "f_B") + ) + + i = F.one_hot( + torch.randint(args.nb_gpts, (c_quizzes.size(0),)), + num_classes=args.nb_gpts, + ) + + prolog = token_prolog_0 * i + token_prolog_2 * (1 - i) + len_prolog, len_quiz = prolog.size(1), c_quizzes.size(1) + + prologued_c_quizzes = torch.cat([prolog, c_quizzes], dim=1).to(main_device) + + T = torch.arange(prologued_c_quizzes.size(1), device=prologued_c_quizzes.device)[ + None, : + ] + + ar_mask = ((T >= len_prolog) & ((T - len_prolog) % (len_quiz // 4) > 0)).long() + + seq_logproba = torch.zeros( + prologued_c_quizzes.size(0), device=prologued_c_quizzes.device + ) + + with torch.autograd.no_grad(): + t = generator.training + generator.eval() + + one_batch_masked_inplace_autoregression( + generator, + prologued_c_quizzes, + ar_mask, + seq_logproba, + deterministic_synthesis=False, + ) + + generator.train(t) + + prologued_c_quizzes = ( + prologued_c_quizzes * (prologued_c_quizzes < vocabulary_size).long() + ) + + return prologued_c_quizzes[:, len_prolog:].to("cpu") + + +def batches_for_generator(generator, quiz_machine, models, w_quizzes=True): + samples = [] + + for _ in range(args.nb_train_samples // args.batch_size): + while sum([x.size(0) for x in samples]) < args.batch_size: + # Generate a bunch of quizzes + + if w_quizzes: + # Either we start with the world quizzes + c_quizzes = quiz_machine.problem.generate_w_quizzes( + args.batch_size, progress_bar=False + ) + else: + # Or we use the generator itself to generate them + c_quizzes = generate_c_quizz_with_generator( + args.batch_size, generator, quiz_machine + ) + + # We remove the trivial ones + to_keep = quiz_machine.problem.trivial(c_quizzes) == False + c_quizzes = c_quizzes[to_keep] + + # If there are remaining ones, we compute the true prolog + # that indicates how the GPTs solve it + + if c_quizzes.size(0) > 0: + seq_logproba = quiz_machine.models_logprobas( + models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) + ) + quiz_machine.models_logprobas( + models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1) + ) + + probas = seq_logproba.exp() + + nu = probas <= args.proba_not_understands + u = probas >= args.proba_understands + + prolog = ( + (nu.long() * token_prolog_0) + + (((nu == False) & (u == False)).long() * token_prolog_1) + + (u.long() * token_prolog_2) + ) + + prologued_c_quizzes = torch.cat([prolog, c_quizzes], dim=1) + + # nb_u = u.long().sum(dim=1) + # nb_nu = nu.long().sum(dim=1) + + # prologued_c_quizzes = prologued_c_quizzes[ + # (nb_u + nb_nu == args.nb_gpts) + # & (nb_nu >= 1) + # & (nb_nu <= args.max_fail_to_validate) + # ] + + samples.append(prologued_c_quizzes) + + # Now we yield a batch + + x = torch.cat(samples, dim=0) + samples = [x[args.batch_size :]] + + yield x[: args.batch_size] + + +def one_generator_epoch( + generator, quiz_machine, models, w_quizzes=True, local_device=main_device +): + model.to(local_device).train() + + optimizer = torch.optim.Adam(generator.parameters(), lr=args.learning_rate) + + nb_train_samples, acc_train_loss = 0, 0.0 + + hard_w_quizzes = [] + + src = batches_for_generator( + generator=generator, quiz_machine=quiz_machine, models=models + ) + + for input in tqdm.tqdm( + src, + dynamic_ncols=True, + desc="training", + total=args.nb_train_samples // args.batch_size, + ): + input = input.to(local_device) + + if nb_train_samples % args.batch_size == 0: + optimizer.zero_grad() + + targets = input + + output = generator(mygpt.BracketedSequence(input)).x + loss = F.cross_entropy(output.transpose(1, 2), targets) + acc_train_loss += loss.item() * input.size(0) + nb_train_samples += input.size(0) + + loss.backward() + + if nb_train_samples % args.batch_size == 0: + optimizer.step() + + train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) + + log_string(f"train_perplexity {n_epoch} generator - {train_perplexity}") + + generator.to(main_device) + + +###################################################################### + + def train_complexifier(model_gen, model_pred1, model_pred2): samples = [] perf = [] @@ -804,170 +964,6 @@ for k in range(args.nb_gpts): ###################################################################### -token_prolog_0 = vocabulary_size + 0 -token_prolog_1 = vocabulary_size + 1 -token_prolog_2 = vocabulary_size + 2 -generator_vocabulary_size = vocabulary_size + 3 - -generator = mygpt.MyGPT( - vocabulary_size=generator_vocabulary_size, - dim_model=args.dim_model, - dim_keys=args.dim_keys, - dim_hidden=args.dim_hidden, - nb_heads=args.nb_heads, - nb_blocks=args.nb_blocks, - causal=True, - dropout=args.dropout, -).to(main_device) - -generator.main_test_accuracy = 0.0 - - -###################################################################### - - -def generate_c_quizz_with_generator(generator, quiz_machine): - c_quizzes = quiz_machine.problem.create_empty_quizzes( - args.batch_size, struct=("A", "f_A", "B", "f_B") - ) - i = F.one_hot( - torch.randint(args.nb_gpts, (c_quizzes.size(0),)), - num_classes=args.nb_gpts, - ) - prolog = token_prolog_0 * i + token_prolog_2 * (1 - i) - c_quizzes = torch.cat([prolog, c_quizzes], dim=1) - ar_mask = ( - torch.arange(c_quizzes.size(1), device=c_quizzes.device)[None, :] - >= args.nb_gpts - ).long() - - one_batch_masked_inplace_autoregression( - generator, - c_quizzes, - ar_mask, - seq_logproba, - deterministic_synthesis=False, - ) - - return c_quizzes[:, args.nb_gpts :] - - -def batches_for_generator(generator=None, quiz_machine=None, device=main_device): - samples = [] - - for _ in range(args.nb_train_samples // args.batch_size): - while sum([x.size(0) for x in samples]) < args.batch_size: - # Generate a bunch of quizzes - - if generator is None: - # Either we start with the world quizzes - c_quizzes = quiz_machine.problem.generate_w_quizzes(args.batch_size) - else: - # Or we use the generator itself to generate them - c_quizzes = generate_c_quizz_with_generator(generator, quiz_machine) - - # We remove the trivial ones - to_keep = quiz_machine.problem.trivial(c_quizzes) == False - c_quizzes = c_quizzes[to_keep] - - # If there are remaining ones, we compute the true prolog - # that indicates how the GPTs solve it - - if c_quizzes.size(0) > 0: - seq_logproba = quiz_machine.models_logprobas( - models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) - ) + quiz_machine.models_logprobas( - models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1) - ) - - probas = seq_logproba.exp() - - nu = probas <= args.proba_not_understands - u = probas >= args.proba_understands - - prolog = ( - (nu.long() * token_prolog_0) - + (u.long() * token_prolog_2) - + ((nu == False & u == False).long() * token_prolog_1) - ) - - samples.append(torch.cat([prolog, c_quizzes], dim=1)) - - # Now we yield a batch - - x = torch.cat(samples, dim=0) - samples = [x[args.batch_size :]] - - yield x[: args.batch_size] - - -def one_generator_epoch( - generator, quiz_machine=None, models=None, local_device=main_device -): - model.to(local_device).train() - - optimizer = torch.optim.Adam(generator.parameters(), lr=args.learning_rate) - - nb_train_samples, acc_train_loss = 0, 0.0 - - hard_w_quizzes = [] - - full_input, full_from_w = quiz_machine.data_input(generator, split="train") - src = zip(full_input.split(args.batch_size), full_from_w.split(args.batch_size)) - - for input, from_w in tqdm.tqdm( - src, - dynamic_ncols=True, - desc="training", - total=full_input.size(0) // args.batch_size, - ): - input = input.to(local_device) - - if nb_train_samples % args.batch_size == 0: - optimizer.zero_grad() - - targets = input - - output = generator(mygpt.BracketedSequence(input)).x - loss_per_token = F.cross_entropy( - output.transpose(1, 2), targets, reduction="none" - ) - loss = loss_per_token.mean() - acc_train_loss += loss.item() * input.size(0) - - loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1) - if from_w.any(): - hard_w_quizzes.append( - (input[from_w].to("cpu"), loss_per_samples[from_w].to("cpu")) - ) - - nb_train_samples += input.size(0) - - loss.backward() - - if nb_train_samples % args.batch_size == 0: - optimizer.step() - - train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) - - log_string( - f"train_perplexity {n_epoch} generator {generator.id} {train_perplexity}" - ) - - run_tests(generator, quiz_machine) - - threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values - threshold = threshold[threshold.size(0) // 2] - - generator.hard_w_quizzes = torch.cat( - [x[l >= threshold] for x, l in hard_w_quizzes], dim=0 - ) - - generator.to(main_device) - - -###################################################################### - current_epoch = 0 if args.resume: @@ -1033,6 +1029,59 @@ if args.dirty_debug: # exit(0) +###################################################################### + +token_prolog_0 = vocabulary_size + 0 +token_prolog_1 = vocabulary_size + 1 +token_prolog_2 = vocabulary_size + 2 +generator_vocabulary_size = vocabulary_size + 3 + +generator = mygpt.MyGPT( + vocabulary_size=generator_vocabulary_size, + dim_model=args.dim_model, + dim_keys=args.dim_keys, + dim_hidden=args.dim_hidden, + nb_heads=args.nb_heads, + nb_blocks=args.nb_blocks, + causal=True, + dropout=args.dropout, +).to(main_device) + +generator.main_test_accuracy = 0.0 + +for n_epoch in range(25): + one_generator_epoch( + generator, + quiz_machine=quiz_machine, + models=models, + w_quizzes=True, + local_device=main_device, + ) + + c_quizzes = generate_c_quizz_with_generator( + generator, quiz_machine, args.batch_size + ) + + seq_logproba = quiz_machine.models_logprobas( + models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) + ) + quiz_machine.models_logprobas( + models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1) + ) + + print(seq_logproba.exp()) + + +one_generator_epoch( + generator, + quiz_machine=quiz_machine, + models=models, + w_quizzes=False, + local_device=main_device, +) + +exit(0) + + ###################################################################### for n_epoch in range(current_epoch, args.nb_epochs): diff --git a/problem.py b/problem.py index 50376d6..9bee5b2 100755 --- a/problem.py +++ b/problem.py @@ -30,7 +30,7 @@ class Problem: quizzes = self.generate_w_quizzes_(self.chunk_size) self.queue.put(quizzes.to("cpu"), block=True) - def generate_w_quizzes(self, nb): + def generate_w_quizzes(self, nb, progress_bar=True): if self.queue is None: return self.generate_w_quizzes_(nb) @@ -43,16 +43,22 @@ class Problem: n = sum([q.size(0) for q in quizzes]) - with tqdm.tqdm( - total=nb, - dynamic_ncols=True, - desc="world generation", - ) as pbar: + if progress_bar: + with tqdm.tqdm( + total=nb, + dynamic_ncols=True, + desc="world generation", + ) as pbar: + while n < nb: + q = self.queue.get(block=True) + quizzes.append(q) + n += q.size(0) + pbar.update(q.size(0)) + else: while n < nb: q = self.queue.get(block=True) quizzes.append(q) n += q.size(0) - pbar.update(q.size(0)) quizzes = torch.cat(quizzes, dim=0) assert n == quizzes.size(0) -- 2.20.1