From 472c399c71cb40e83cd3242fd82c3d78a280a058 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 31 Jul 2024 09:06:42 +0200 Subject: [PATCH] Update. --- main.py | 80 ++++++++++++++++++++++++++------------------------------- 1 file changed, 36 insertions(+), 44 deletions(-) diff --git a/main.py b/main.py index d50837a..3cc536c 100755 --- a/main.py +++ b/main.py @@ -678,30 +678,30 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 def generate_c_quizz_with_generator(generator, quiz_machine, nb): generator.to(main_device) - c_quizzes = quiz_machine.problem.create_empty_quizzes( - nb, struct=("A", "f_A", "B", "f_B") - ) + struct = ("A", "f_A", "B", "f_B") + + c_quizzes = quiz_machine.problem.create_empty_quizzes(nb, struct=struct) + ar_mask = quiz_machine.make_ar_mask(c_quizzes, struct, (1, 1, 1, 1)) i = F.one_hot( torch.randint(args.nb_gpts, (c_quizzes.size(0),)), num_classes=args.nb_gpts, ) - prolog = token_prolog_0 * i + token_prolog_2 * (1 - i) - len_prolog, len_quiz = prolog.size(1), c_quizzes.size(1) - - prologued_c_quizzes = torch.cat([prolog, c_quizzes], dim=1).to(main_device) - - T = torch.arange(prologued_c_quizzes.size(1), device=prologued_c_quizzes.device)[ - None, : - ] + prolog_c_quizzes = token_prolog_0 * i + token_prolog_2 * (1 - i) + prolog_ar_mask = ar_mask.new_zeros(ar_mask.size(0), prolog_c_quizzes.size(1)) - ar_mask = ((T >= len_prolog) & ((T - len_prolog) % (len_quiz // 4) > 0)).long() + prologued_c_quizzes = torch.cat([prolog_c_quizzes, c_quizzes], dim=1).to( + main_device + ) + prologued_ar_mask = torch.cat([prolog_ar_mask, ar_mask], dim=1).to(main_device) seq_logproba = torch.zeros( prologued_c_quizzes.size(0), device=prologued_c_quizzes.device ) + generator.temperature = args.temperature_hot + with torch.autograd.no_grad(): t = generator.training generator.eval() @@ -709,28 +709,30 @@ def generate_c_quizz_with_generator(generator, quiz_machine, nb): one_batch_masked_inplace_autoregression( generator, prologued_c_quizzes, - ar_mask, + prologued_ar_mask, seq_logproba, deterministic_synthesis=False, ) generator.train(t) + generator.reset_transformations() + prologued_c_quizzes = ( prologued_c_quizzes * (prologued_c_quizzes < vocabulary_size).long() ) - return prologued_c_quizzes[:, len_prolog:].to("cpu") + return prologued_c_quizzes[:, prolog_c_quizzes.size(1) :].to("cpu") -def batches_for_generator(generator, quiz_machine, models, w_quizzes=True): +def batches_for_generator(generator, quiz_machine, models, fraction_w_quizzes=1.0): samples = [] for _ in range(args.nb_train_samples // args.batch_size): while sum([x.size(0) for x in samples]) < args.batch_size: # Generate a bunch of quizzes - if w_quizzes: + if torch.rand(1).item() <= fraction_w_quizzes: # Either we start with the world quizzes c_quizzes = quiz_machine.problem.generate_w_quizzes( args.batch_size, progress_bar=False @@ -738,7 +740,7 @@ def batches_for_generator(generator, quiz_machine, models, w_quizzes=True): else: # Or we use the generator itself to generate them c_quizzes = generate_c_quizz_with_generator( - args.batch_size, generator, quiz_machine + generator, quiz_machine, args.batch_size ) # We remove the trivial ones @@ -757,27 +759,24 @@ def batches_for_generator(generator, quiz_machine, models, w_quizzes=True): probas = seq_logproba.exp() - nu = probas <= args.proba_not_understands - u = probas >= args.proba_understands + u0 = probas <= args.proba_not_understands + u2 = probas >= args.proba_understands + u1 = (u0 | u2) == False prolog = ( - (nu.long() * token_prolog_0) - + (((nu == False) & (u == False)).long() * token_prolog_1) - + (u.long() * token_prolog_2) + (u0.long() * token_prolog_0) + + (u1.long() * token_prolog_1) + + (u2.long() * token_prolog_2) ) prologued_c_quizzes = torch.cat([prolog, c_quizzes], dim=1) - # nb_u = u.long().sum(dim=1) - # nb_nu = nu.long().sum(dim=1) - - # prologued_c_quizzes = prologued_c_quizzes[ - # (nb_u + nb_nu == args.nb_gpts) - # & (nb_nu >= 1) - # & (nb_nu <= args.max_fail_to_validate) - # ] + # nb_u2 = u2.long().sum(dim=1) + # nb_u0 = u0.long().sum(dim=1) + # prologued_c_quizzes = prologued_c_quizzes[(nb_u2 >= 1) & (nb_u0 >= 1)] - samples.append(prologued_c_quizzes) + if prologued_c_quizzes.size(0) > 0: + samples.append(prologued_c_quizzes) # Now we yield a batch @@ -788,7 +787,7 @@ def batches_for_generator(generator, quiz_machine, models, w_quizzes=True): def one_generator_epoch( - generator, quiz_machine, models, w_quizzes=True, local_device=main_device + generator, quiz_machine, models, fraction_w_quizzes, local_device=main_device ): model.to(local_device).train() @@ -796,10 +795,11 @@ def one_generator_epoch( nb_train_samples, acc_train_loss = 0, 0.0 - hard_w_quizzes = [] - src = batches_for_generator( - generator=generator, quiz_machine=quiz_machine, models=models + generator=generator, + quiz_machine=quiz_machine, + models=models, + fraction_w_quizzes=fraction_w_quizzes, ) for input in tqdm.tqdm( @@ -1047,7 +1047,7 @@ if args.test_generator: generator, quiz_machine=quiz_machine, models=models, - w_quizzes=True, + fraction_w_quizzes=1 if n_epoch < 25 else 0.5, local_device=main_device, ) @@ -1081,14 +1081,6 @@ if args.test_generator: ) log_string(f"wrote {filename}") - one_generator_epoch( - generator, - quiz_machine=quiz_machine, - models=models, - w_quizzes=False, - local_device=main_device, - ) - exit(0) ###################################################################### -- 2.39.5