From fb685ce7971beb6485c94c7fb43312ba0cdf8d41 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 30 Jul 2024 19:41:04 +0200 Subject: [PATCH] Update. --- main.py | 119 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 115 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 455aa1c..19a3c29 100755 --- a/main.py +++ b/main.py @@ -49,6 +49,8 @@ parser.add_argument("--batch_size", type=int, default=None) parser.add_argument("--physical_batch_size", type=int, default=None) +parser.add_argument("--inference_batch_size", type=int, default=None) + parser.add_argument("--nb_train_samples", type=int, default=None) parser.add_argument("--nb_test_samples", type=int, default=None) @@ -157,6 +159,7 @@ assert not args.grids_science_tasks or ( default_args = { "model": "37M", "batch_size": 25, + "inference_batch_size": 100, "nb_train_samples": 100000, "nb_test_samples": 10000, } @@ -336,7 +339,7 @@ if not args.resume: quiz_machine = quiz_machine.QuizMachine( problem=problem, - batch_size=args.physical_batch_size, + batch_size=args.inference_batch_size, result_dir=args.result_dir, logger=log_string, device=main_device, @@ -670,6 +673,106 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 ###################################################################### +def train_complexifier(model_gen, model_pred1, model_pred2): + samples = [] + perf = [] + + optimizer = torch.optim.Adam(model_gen.parameters(), lr=args.learning_rate) + + nb_train_samples, acc_train_loss = 0, 0.0 + + for n_epoch in range(args.nb_epochs): + for b in range(args.nb_train_samples // args.batch_size): + while sum([x.size(0) for x in samples]) < args.batch_size: + c_quizzes = quiz_machine.generate_c_quizzes( + args.inference_batch_size, + model_for_generation=model_gen, + procedure=c_quizzes_procedure, + ) + to_keep = quiz_machine.problem.trivial(c_quizzes) == False + c_quizzes = c_quizzes[to_keep] + if c_quizzes.size(0) > 0: + seq_logproba = quiz_machine.models_logprobas( + [model_pred1, model_pred2], + c_quizzes, + ("A", "f_A", "B", "f_B"), + (0, 0, 0, 1), + ) + quiz_machine.models_logprobas( + [model_pred1, model_pred2], + c_quizzes, + ("f_A", "A", "f_B", "B"), + (0, 0, 0, 1), + ) + probas = seq_logproba.exp() + to_keep = (probas[:, model_pred1.id] >= args.proba_understands) & ( + probas[:, model_pred2.id] <= args.proba_not_understands + ) + log_string( + f"generating {to_keep.long().sum()} / {c_quizzes.size(0)}" + ) + c_quizzes = c_quizzes[to_keep] + if c_quizzes.size(0): + samples.append(c_quizzes) + + log_string(f"full batch {sum([x.size(0) for x in samples])}") + + x = torch.cat(samples, dim=0) + + input = x[: args.batch_size] + samples = [x[args.batch_size :]] + + # ------------------- + + seq_logproba = quiz_machine.models_logprobas( + [model_pred1, model_pred2], + input, + ("A", "f_A", "B", "f_B"), + (0, 0, 0, 1), + ) + quiz_machine.models_logprobas( + [model_pred1, model_pred2], + input, + ("f_A", "A", "f_B", "B"), + (0, 0, 0, 1), + ) + + comments = [] + + for l in seq_logproba: + comments.append( + f"proba {l[model_pred1.id].exp().item():.02f} {l[model_pred2.id].exp().item():.02f}" + ) + + filename = f"batch_{n_epoch:04d}_{b:04d}.png" + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, filename, input, comments=comments + ) + log_string(f"wrote {filename}") + + # ------------------------ + + input = input.to(main_device) + + if nb_train_samples % args.batch_size == 0: + optimizer.zero_grad() + + output = model_gen(mygpt.BracketedSequence(input)).x + loss = F.cross_entropy(output.transpose(1, 2), input) + acc_train_loss += loss.item() * input.size(0) + nb_train_samples += input.size(0) + + loss.backward() + + if nb_train_samples % args.batch_size == 0: + optimizer.step() + + train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) + + log_string(f"train_perplexity {n_epoch} model ae {train_perplexity}") + + +###################################################################### + + def train_autoencoder(): model = mygpt.MyGPT( vocabulary_size=vocabulary_size, @@ -769,9 +872,9 @@ def train_autoencoder(): return model -if args.autoencoder_dim > 0: - ae = train_autoencoder() - exit(0) +# if args.autoencoder_dim > 0: +# ae = train_autoencoder() +# exit(0) ###################################################################### @@ -864,6 +967,14 @@ if args.dirty_debug: ###################################################################### +# DIRTY TEST + +# train_complexifier(models[0], models[1], models[2]) + +# exit(0) + +###################################################################### + for n_epoch in range(current_epoch, args.nb_epochs): state = {"current_epoch": n_epoch} filename = "state.pth" -- 2.20.1