From 982438ec146974f415072ff98523503fc8721538 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 14 Jul 2024 18:20:01 +0200 Subject: [PATCH] Update. --- main.py | 112 ++++++++++++++++++++++++------------------------ quiz_machine.py | 15 ++++--- 2 files changed, 66 insertions(+), 61 deletions(-) diff --git a/main.py b/main.py index 6c4099f..7ba5193 100755 --- a/main.py +++ b/main.py @@ -48,6 +48,10 @@ parser.add_argument("--nb_train_samples", type=int, default=None) parser.add_argument("--nb_test_samples", type=int, default=None) +parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None) + +parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None) + parser.add_argument("--learning_rate", type=float, default=5e-4) ######################################## @@ -78,7 +82,7 @@ parser.add_argument("--gpus", type=str, default="all") parser.add_argument("--nb_gpts", type=int, default=5) -parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975) +parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9) parser.add_argument("--proba_understands", type=float, default=0.99) @@ -366,20 +370,28 @@ def one_epoch(model, quiz_machine, local_device=main_device): ###################################################################### +# This is the key routine that decides what generated quizzes to keep + + +def compute_valid_quizzes(token_logprobas): + warnings.warn("validation with uniform constraints", RuntimeWarning) + l = token_logprobas.min(dim=-1).values.sort(dim=-1).values + return (l[:, 0] < math.log(0.1)) & (l[:, 1] > math.log(0.5)) + -def standard_validity(logproba): - l = logproba.sort(dim=-1).values +def compute_valid_quizzes_(token_logprobas): + l = token_logprobas.sum(dim=-1).sort(dim=-1).values return (l[:, 0] < math.log(args.proba_not_understands)) & ( l[:, 1] > math.log(args.proba_understands) ) -def valid_quizzes_and_logprobas(recorded, criteria): +def extract_valid_quizzes_and_logprobas(recorded): validated_quizzes, validated_logprobas = [], [] - for q, lp in recorded: - validated_indices = criteria(lp) - validated_quizzes.append(q[validated_indices]) - validated_logprobas.append(lp[validated_indices]) + for quizzes, token_logprobas in recorded: + validated_indices = compute_valid_quizzes(token_logprobas) + validated_quizzes.append(quizzes[validated_indices]) + validated_logprobas.append(token_logprobas[validated_indices]) if len(validated_quizzes) > 0: return torch.cat(validated_quizzes, dim=0), torch.cat( @@ -411,12 +423,13 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100): c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)] if c_quizzes.size(0) > 0: - logproba = quiz_machine.logproba_of_solutions(models, c_quizzes) - recorded_quizzes_logprobas.append((c_quizzes, logproba)) + token_logproba = quiz_machine.solution_token_logprobas(models, c_quizzes) + recorded_quizzes_logprobas.append((c_quizzes, token_logproba)) - validated_quizzes, validated_logprobas = valid_quizzes_and_logprobas( - recorded_quizzes_logprobas, standard_validity - ) + ( + validated_quizzes, + validated_logprobas, + ) = extract_valid_quizzes_and_logprobas(recorded_quizzes_logprobas) if validated_quizzes is not None: nb_validated = validated_quizzes.size(0) @@ -433,19 +446,6 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100): validated_quizzes[nb_for_train:nb_to_create], for_train=False ) - ###################################################################### - # save the log probas - - file_name = os.path.join( - args.result_dir, f"culture_c_quiz_all_{n_epoch:04d}_logp.dat" - ) - - with open(file_name, "w") as logp_file: - for _, ll in recorded_quizzes_logprobas: - for l in ll: - s = " ".join([str(x.item()) for x in l]) - logp_file.write(s + "\n") - ###################################################################### # save images with their logprobas @@ -454,12 +454,12 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100): if vq.size(0) > 0: prefix = f"culture_c_quiz_{n_epoch:04d}" - - file_name = os.path.join(args.result_dir, prefix + "_logp.dat") - with open(file_name, "w") as logp_file: - for l in vl: - s = " ".join([str(x.item()) for x in l]) - logp_file.write(s + "\n") + filename = os.path.join(args.result_dir, prefix + "_logp.pth") + torch.save(vl, filename) + # with open(file_name, "w") as logp_file: + # for l in vl: + # s = " ".join([str(x.item()) for x in l]) + # logp_file.write(s + "\n") quiz_machine.save_quiz_illustrations(args.result_dir, prefix, vq) @@ -574,11 +574,14 @@ if args.max_percents_of_test_in_train >= 0: ###################################################################### -nb_new_c_quizzes_for_train = args.nb_train_samples // 50 -nb_new_c_quizzes_for_test = args.nb_test_samples // 50 +if args.nb_new_c_quizzes_for_train is None: + args.nb_new_c_quizzes_for_train = args.nb_train_samples // 50 + +if args.nb_new_c_quizzes_for_test is None: + args.nb_new_c_quizzes_for_test = args.nb_test_samples // 50 log_string( - f"nb_new_c_quizzes_for_train {nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {nb_new_c_quizzes_for_test}" + f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}" ) ###################################################################### @@ -586,12 +589,8 @@ log_string( if args.dirty_debug: args.accuracy_to_make_c_quizzes = 0.0 args.nb_gpts = 2 - nb_new_c_quizzes_for_train = 100 - nb_new_c_quizzes_for_test = 10 - - def standard_validity(logproba): - l = logproba.sort(dim=-1).values - return l[:, 0] < math.log(0.5) + args.nb_new_c_quizzes_for_train = 100 + args.nb_new_c_quizzes_for_test = 10 ###################################################################### @@ -602,6 +601,22 @@ for n_epoch in range(args.nb_epochs): cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models]) log_string(f"current_test_accuracies {cta}") + ################################################## + # If all the models are good enough, generate new quizzes and + # re-compute the test errors + + if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes: + create_c_quizzes( + models, + quiz_machine, + nb_for_train=args.nb_new_c_quizzes_for_train, + nb_for_test=args.nb_new_c_quizzes_for_test, + ) + + filename = "c_quizzes.pth" + quiz_machine.save_c_quizzes(os.path.join(args.result_dir, filename)) + log_string(f"wrote {filename}") + ################################################## # Select, improve, and eval the worst model @@ -640,20 +655,5 @@ for n_epoch in range(args.nb_epochs): for model in weakest_models: quiz_machine.renew_w_quizzes(model, args.nb_train_samples) - ################################################## - # If all the models are good enough, generate new quizzes and - # re-compute the test errors - - if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes: - create_c_quizzes( - models, - quiz_machine, - nb_for_train=nb_new_c_quizzes_for_train, - nb_for_test=nb_new_c_quizzes_for_test, - ) - - filename = "c_quizzes.pth" - quiz_machine.save_c_quizzes(os.path.join(args.result_dir, filename)) - log_string(f"wrote {filename}") ###################################################################### diff --git a/quiz_machine.py b/quiz_machine.py index c49ecf2..bc468d3 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -450,9 +450,13 @@ class QuizMachine: ###################################################################### - def logproba_of_solutions(self, models, c_quizzes): + def solution_token_logprobas(self, models, c_quizzes): logproba = c_quizzes.new_zeros( - c_quizzes.size(0), len(models), device=self.device, dtype=torch.float32 + c_quizzes.size(0), + len(models), + c_quizzes.size(1), + device=self.device, + dtype=torch.float32, ) for model in models: @@ -466,11 +470,12 @@ class QuizMachine: input = input.to(self.device) ar_mask = self.make_ar_mask(input) output = model(mygpt.BracketedSequence(input)).x - ce = ( - F.cross_entropy(output.transpose(1, 2), input, reduction="none") + l[:, model.id] = ( + -F.cross_entropy( + output.transpose(1, 2), input, reduction="none" + ) * ar_mask ) - l[:, model.id] = -ce.sum(dim=-1) model.train(t) -- 2.39.5