From 5b5adeae3d827dd26d7b9e304a04faef0c14c68c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 31 Jul 2024 07:07:28 +0200 Subject: [PATCH] Update. --- main.py | 105 ++++++++++++++++++++++++++++++-------------------------- 1 file changed, 57 insertions(+), 48 deletions(-) diff --git a/main.py b/main.py index 50e34a8..d50837a 100755 --- a/main.py +++ b/main.py @@ -107,7 +107,7 @@ parser.add_argument("--nb_rounds", type=int, default=1) parser.add_argument("--dirty_debug", action="store_true", default=False) -parser.add_argument("--autoencoder_dim", type=int, default=-1) +parser.add_argument("--test_generator", action="store_true", default=False) ###################################################################### @@ -1023,64 +1023,73 @@ if args.dirty_debug: ###################################################################### -# DIRTY TEST +if args.test_generator: + token_prolog_0 = vocabulary_size + 0 + token_prolog_1 = vocabulary_size + 1 + token_prolog_2 = vocabulary_size + 2 + generator_vocabulary_size = vocabulary_size + 3 -# train_complexifier(models[0], models[1], models[2]) + generator = mygpt.MyGPT( + vocabulary_size=generator_vocabulary_size, + dim_model=args.dim_model, + dim_keys=args.dim_keys, + dim_hidden=args.dim_hidden, + nb_heads=args.nb_heads, + nb_blocks=args.nb_blocks, + causal=True, + dropout=args.dropout, + ).to(main_device) -# exit(0) + generator.main_test_accuracy = 0.0 -###################################################################### + for n_epoch in range(args.nb_epochs): + one_generator_epoch( + generator, + quiz_machine=quiz_machine, + models=models, + w_quizzes=True, + local_device=main_device, + ) -token_prolog_0 = vocabulary_size + 0 -token_prolog_1 = vocabulary_size + 1 -token_prolog_2 = vocabulary_size + 2 -generator_vocabulary_size = vocabulary_size + 3 - -generator = mygpt.MyGPT( - vocabulary_size=generator_vocabulary_size, - dim_model=args.dim_model, - dim_keys=args.dim_keys, - dim_hidden=args.dim_hidden, - nb_heads=args.nb_heads, - nb_blocks=args.nb_blocks, - causal=True, - dropout=args.dropout, -).to(main_device) - -generator.main_test_accuracy = 0.0 - -for n_epoch in range(25): - one_generator_epoch( - generator, - quiz_machine=quiz_machine, - models=models, - w_quizzes=True, - local_device=main_device, - ) + filename = f"generator.pth" + torch.save( + (generator.state_dict(), generator.main_test_accuracy), + os.path.join(args.result_dir, filename), + ) + log_string(f"wrote {filename}") - c_quizzes = generate_c_quizz_with_generator( - generator, quiz_machine, args.batch_size - ) + c_quizzes = generate_c_quizz_with_generator( + generator, quiz_machine, args.batch_size + ) - seq_logproba = quiz_machine.models_logprobas( - models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) - ) + quiz_machine.models_logprobas( - models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1) - ) + seq_logproba = quiz_machine.models_logprobas( + models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) + ) + quiz_machine.models_logprobas( + models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1) + ) - print(seq_logproba.exp()) + print(seq_logproba.exp()) + comments = [] -one_generator_epoch( - generator, - quiz_machine=quiz_machine, - models=models, - w_quizzes=False, - local_device=main_device, -) + for l in seq_logproba: + comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l])) + + filename = f"generator_batch_{n_epoch:04d}.png" + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, filename, c_quizzes, comments=comments + ) + log_string(f"wrote {filename}") -exit(0) + one_generator_epoch( + generator, + quiz_machine=quiz_machine, + models=models, + w_quizzes=False, + local_device=main_device, + ) + exit(0) ###################################################################### -- 2.39.5