From e15bb32622feb3751d1302af1a33a5fbf95d3ea6 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 13 Aug 2024 00:03:42 +0200 Subject: [PATCH] Update. --- main.py | 629 +++++++----------------------------------------- quiz_machine.py | 81 +++---- 2 files changed, 123 insertions(+), 587 deletions(-) diff --git a/main.py b/main.py index e516a77..0a79323 100755 --- a/main.py +++ b/main.py @@ -391,7 +391,7 @@ def run_tests(model, quiz_machine, local_device=main_device): nb_samples_accumulated = 0 full_input, full_mask_loss = quiz_machine.data_input( - model, args.nb_test_samples + args.nb_test_samples, model.test_c_quiz_bags ) src = zip( full_input.split(args.batch_size), full_mask_loss.split(args.batch_size) @@ -441,7 +441,9 @@ def one_epoch(model, quiz_machine, local_device=main_device): hard_w_quizzes = [] - full_input, full_mask_loss = quiz_machine.data_input(model, args.nb_train_samples) + full_input, full_mask_loss = quiz_machine.data_input( + args.nb_train_samples, model.train_c_quiz_bags + ) src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)) for input, mask_loss in tqdm.tqdm( @@ -528,17 +530,23 @@ def save_additional_results(model, models, science_w_quizzes): # This is nb_quizzes x nb_models - seq_logproba = quiz_machine.models_logprobas( - models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) + quiz_machine.models_logprobas( - models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) + l = [ + quiz_machine.models_logprobas( + model, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + + quiz_machine.models_logprobas( + model, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + for model in models + ] + + seq_logprobas = torch.cat([x[None, :] for x in l]) - probas = seq_logproba.exp() + probas = seq_logprobas.exp() comments = [] - for l in seq_logproba: + for l in seq_logprobas: comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l])) ## @@ -616,18 +624,26 @@ def save_additional_results(model, models, science_w_quizzes): ###################################################################### -def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100): - nb_to_validate = nb_for_train + nb_for_test - nb_to_generate_per_iteration = max(args.physical_batch_size, nb_to_validate) - nb_validated = 0 +def model_proba_solutions(m, quizzes): + l = quiz_machine.models_logprobas( + m, quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + quiz_machine.models_logprobas( + m, quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + + return l.exp() + - recorded_validated = [] +def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test): + nb_validated, nb_to_validate = 0, nb_for_train + nb_for_test + nb_to_generate_per_iteration = nb_to_validate start_time = time.perf_counter() - nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64) + for model in models: + model.recorded_c_quizzes = [] - while nb_validated_per_model.sum() < nb_to_validate: + while nb_validated < nb_to_validate: model_for_generation = models[torch.randint(len(models), (1,)).item()] # We generate quizzes with a procedure that injects some @@ -646,80 +662,48 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 c_quizzes = c_quizzes[to_keep] - # This is nb_quizzes x nb_models + # Compute the responses of all the models on the c_quizzes, + # and their proba estimates of their responses solved_c_quizzes = c_quizzes[:, None, :].expand(-1, len(models), -1).clone() - seq_logproba = torch.zeros( + proba_own_solution = torch.zeros( c_quizzes.size(0), len(models), device=solved_c_quizzes.device ) - for m in models: - ( - solved_c_quizzes[:, m.id], - _, - seq_logproba[:, m.id], - ) = quiz_machine.predict( - m, - solved_c_quizzes[:, m.id], + for model in models: + (solved_c_quizzes[:, model.id], _, _) = quiz_machine.predict( + model, + solved_c_quizzes[:, model.id], struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1), ) - #!!!!!!!!!!!!!!!!!!!! - for m in range(seq_logproba.size(1)): - l = quiz_machine.models_logprobas( - [models[m]], - solved_c_quizzes[:, m, :], - ("A", "f_A", "B", "f_B"), - (0, 0, 0, 1), - (0, 0, 0, 0), - ) - for s in range(seq_logproba.size(0)): - print("DEBUG", seq_logproba[s, m].item(), l[s, 0].item()) - exit(0) - #!!!!!!!!!!!!!!!!!!!!!!!!! + u = model_proba_solutions(model, solved_c_quizzes[:, model.id]) - # FINISH + proba_own_solution[:, model.id] = u - seq_logproba = quiz_machine.models_logprobas( - models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) + quiz_machine.models_logprobas( - models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) - - probas = seq_logproba.exp() - - nb_succeed = (probas >= args.proba_understands).long().sum(dim=1) - nb_fail = (probas <= args.proba_not_understands).long().sum(dim=1) - - to_keep = ( - # (nb_succeed + nb_fail == probas.size(1)) - (nb_succeed >= args.min_succeed_to_validate) - & (nb_fail >= 1) - & (nb_fail <= args.max_fail_to_validate) - ) - - c_quizzes = c_quizzes[to_keep] - - if c_quizzes.size(0) > 0: - nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0) - recorded_validated.append(c_quizzes) - nb_validated = c_quizzes.size(0) - else: - nb_validated = 0 + # Now for every model not confident of its response, we pick + # the most consistent from a model which is confident - total_nb_validated = nb_validated_per_model.sum().item() + for s in range(proba_own_solution.size(0)): + dont_get_it = proba_own_solution[s, :] < args.proba_understands + if not dont_get_it.all(): + for model in models: + if dont_get_it[model.id]: + proba_other_solutions = model_proba_solutions( + model, solved_c_quizzes[s] + ) + proba_other_solutions[dont_get_it] = -1 + i = proba_other_solutions.argmax() + model.recorded_c_quizzes.append(solved_c_quizzes[s, i]) + nb_validated += 1 duration = time.perf_counter() - start_time - if total_nb_validated > 0: - if total_nb_validated < nb_to_validate: - d = ( - (nb_to_validate - total_nb_validated) - * duration - / total_nb_validated - ) + if nb_validated > 0: + if nb_validated < nb_to_validate: + d = (nb_to_validate - nb_validated) * duration / nb_validated e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime( "%a %H:%M" ) @@ -729,320 +713,44 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 e = "???" log_string( - f"keep c_quizzes model {model_for_generation.id} validated {nb_validated} / {nb_to_generate_per_iteration} ({100*nb_validated/nb_to_generate_per_iteration:.02f}%) nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)" - ) - - validated_quizzes = torch.cat(recorded_validated, dim=0) - - ###################################################################### - # store the new c_quizzes which have been validated - - v_train = validated_quizzes[:nb_for_train] - quiz_machine.store_c_quizzes(v_train, for_train=True) - - v_test = validated_quizzes[nb_for_train:nb_to_validate] - quiz_machine.store_c_quizzes(v_test, for_train=False) - - ###################################################################### - # save images - - vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]] - - if vq.size(0) > 0: - seq_logproba = quiz_machine.models_logprobas( - models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) + quiz_machine.models_logprobas( - models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) - - probas = seq_logproba.exp() - - comments = [] - - for l in seq_logproba: - comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l])) - - filename = f"culture_c_quiz_{n_epoch:04d}.png" - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, filename, vq, comments=comments - ) - - -###################################################################### - -# The generator is very similar to a "solving GPT" except that it -# deals with quizzes prologued with one token per solving GPT that -# indicates if the said model solves it or not. -# -# There are three levels of solving 0->proba<=proba_not_understands, -# 2->proba>=proba_understands and 1 otherwise. - - -def generate_c_quizzes_with_generator(generator, quiz_machine, nb): - generator.to(main_device) - - struct = ("A", "f_A", "B", "f_B") - - c_quizzes = quiz_machine.problem.create_empty_quizzes(nb, struct=struct) - ar_mask = quiz_machine.make_quiz_mask(c_quizzes, struct, (1, 1, 1, 1)) - - i = F.one_hot( - torch.randint(args.nb_gpts, (c_quizzes.size(0),)), - num_classes=args.nb_gpts, - ) - - prologs_c_quizzes = token_prolog_0 * i + token_prolog_2 * (1 - i) - prologs_ar_mask = ar_mask.new_zeros(ar_mask.size(0), prologs_c_quizzes.size(1)) - - prologued_c_quizzes = torch.cat([prologs_c_quizzes, c_quizzes], dim=1).to( - main_device - ) - prologued_ar_mask = torch.cat([prologs_ar_mask, ar_mask], dim=1).to(main_device) - - seq_logproba = torch.zeros( - prologued_c_quizzes.size(0), device=prologued_c_quizzes.device - ) - - generator.temperature = args.temperature_hot - - with torch.autograd.no_grad(): - t = generator.training - generator.eval() - - one_batch_masked_inplace_autoregression( - generator, - prologued_c_quizzes, - prologued_ar_mask, - seq_logproba, - deterministic_synthesis=False, + f"keep c_quizzes model {model_for_generation.id} validated {nb_validated} / {nb_to_generate_per_iteration} ({100*nb_validated/nb_to_generate_per_iteration:.02f}%) nb_accumulated {nb_validated} / {nb_to_validate} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h)" ) - generator.train(t) - - generator.reset_transformations() - - prologued_c_quizzes = ( - prologued_c_quizzes * (prologued_c_quizzes < vocabulary_size).long() - ) - - c_quizzes = prologued_c_quizzes[:, prologs_c_quizzes.size(1) :] - - return c_quizzes.to("cpu"), prologs_c_quizzes.to("cpu") - - -def batches_for_generator(generator, quiz_machine, models, fraction_w_quizzes=1.0): - samples = [] - - for _ in range(args.nb_train_samples // args.batch_size): - while sum([x.size(0) for x in samples]) < args.batch_size: - # Generate a bunch of quizzes - - if torch.rand(1).item() <= fraction_w_quizzes: - # Either we start with the world quizzes - c_quizzes = quiz_machine.problem.generate_w_quizzes( - args.batch_size, progress_bar=False - ) - else: - # Or we use the generator itself to generate them - c_quizzes, _ = generate_c_quizzes_with_generator( - generator, quiz_machine, args.batch_size - ) - - # We remove the trivial ones - to_keep = quiz_machine.problem.trivial(c_quizzes) == False - c_quizzes = c_quizzes[to_keep] - - # If there are remaining ones, we compute the true prolog - # that indicates how the GPTs solve it - - if c_quizzes.size(0) > 0: - seq_logproba = quiz_machine.models_logprobas( - models, - c_quizzes, - ("A", "f_A", "B", "f_B"), - (0, 0, 0, 1), - (0, 0, 1, 0), - ) + quiz_machine.models_logprobas( - models, - c_quizzes, - ("f_A", "A", "f_B", "B"), - (0, 0, 0, 1), - (0, 0, 1, 0), - ) - - probas = seq_logproba.exp() - - u0 = probas <= args.proba_not_understands - u2 = probas >= args.proba_understands - u1 = (u0 | u2) == False - - prologs = ( - (u0.long() * token_prolog_0) - + (u1.long() * token_prolog_1) - + (u2.long() * token_prolog_2) - ) - - prologued_c_quizzes = torch.cat([prologs, c_quizzes], dim=1) - - # nb_u2 = u2.long().sum(dim=1) - # nb_u0 = u0.long().sum(dim=1) - # prologued_c_quizzes = prologued_c_quizzes[(nb_u2 >= 1) & (nb_u0 >= 1)] - - if prologued_c_quizzes.size(0) > 0: - samples.append(prologued_c_quizzes) - - # Now we yield a batch - - x = torch.cat(samples, dim=0) - samples = [x[args.batch_size :]] - - yield x[: args.batch_size] - - -def one_generator_epoch( - generator, quiz_machine, models, fraction_w_quizzes, local_device=main_device -): - model.to(local_device).train() - - optimizer = torch.optim.Adam(generator.parameters(), lr=args.learning_rate) - - nb_train_samples, acc_train_loss = 0, 0.0 - - src = batches_for_generator( - generator=generator, - quiz_machine=quiz_machine, - models=models, - fraction_w_quizzes=fraction_w_quizzes, - ) - - for input in tqdm.tqdm( - src, - dynamic_ncols=True, - desc="training", - total=args.nb_train_samples // args.batch_size, - ): - input = input.to(local_device) - - if nb_train_samples % args.batch_size == 0: - optimizer.zero_grad() - - targets = input - - output = generator(mygpt.BracketedSequence(input)).x - loss = F.cross_entropy(output.transpose(1, 2), targets) - acc_train_loss += loss.item() * input.size(0) - nb_train_samples += input.size(0) - - loss.backward() - - if nb_train_samples % args.batch_size == 0: - optimizer.step() - - train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) - - log_string(f"train_perplexity {n_epoch} generator - {train_perplexity}") - - generator.to(main_device) - - -###################################################################### - - -def train_complexifier(model_gen, model_pred1, model_pred2): - samples = [] - perf = [] + for model in models: + new_bag = torch.cat([q[None, :] for q in model.recorded_c_quizzes], dim=0) - optimizer = torch.optim.Adam(model_gen.parameters(), lr=args.learning_rate) + if new_bag.size(0) > 0: + n = (new_bag.size(0) * nb_for_train) // (nb_for_train + nb_for_test) + if n > 0: + model.train_c_quiz_bags.append(new_bag[:n]) + if n < new_bag.size(0): + model.test_c_quiz_bags.append(new_bag[:n]) - nb_train_samples, acc_train_loss = 0, 0.0 + vq = new_bag[:128] - for n_epoch in range(args.nb_epochs): - for b in range(args.nb_train_samples // args.batch_size): - while sum([x.size(0) for x in samples]) < args.batch_size: - c_quizzes = quiz_machine.generate_c_quizzes( - args.inference_batch_size, - model_for_generation=model_gen, - procedure=c_quizzes_procedure, - ) - to_keep = quiz_machine.problem.trivial(c_quizzes) == False - c_quizzes = c_quizzes[to_keep] - if c_quizzes.size(0) > 0: - seq_logproba = quiz_machine.models_logprobas( - [model_pred1, model_pred2], - c_quizzes, - ("A", "f_A", "B", "f_B"), - (0, 0, 0, 1), - ) + quiz_machine.models_logprobas( - [model_pred1, model_pred2], - c_quizzes, - ("f_A", "A", "f_B", "B"), - (0, 0, 0, 1), - ) - probas = seq_logproba.exp() - to_keep = (probas[:, model_pred1.id] >= args.proba_understands) & ( - probas[:, model_pred2.id] <= args.proba_not_understands - ) - log_string( - f"generating {to_keep.long().sum()} / {c_quizzes.size(0)}" - ) - c_quizzes = c_quizzes[to_keep] - if c_quizzes.size(0): - samples.append(c_quizzes) - - log_string(f"full batch {sum([x.size(0) for x in samples])}") - - x = torch.cat(samples, dim=0) - - input = x[: args.batch_size] - samples = [x[args.batch_size :]] - - # ------------------- - - seq_logproba = quiz_machine.models_logprobas( - [model_pred1, model_pred2], - input, - ("A", "f_A", "B", "f_B"), - (0, 0, 0, 1), + seq_logprobas = quiz_machine.models_logprobas( + models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) ) + quiz_machine.models_logprobas( - [model_pred1, model_pred2], - input, - ("f_A", "A", "f_B", "B"), - (0, 0, 0, 1), + models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) ) + probas = seq_logprobas.exp() + comments = [] - for l in seq_logproba: + for l in seq_logprobas: comments.append( - f"proba {l[model_pred1.id].exp().item():.02f} {l[model_pred2.id].exp().item():.02f}" + "proba " + " ".join([f"{x.exp().item():.02f}" for x in l]) ) - filename = f"batch_{n_epoch:04d}_{b:04d}.png" + filename = f"culture_c_quiz_{n_epoch:04d}.png" quiz_machine.problem.save_quizzes_as_image( - args.result_dir, filename, input, comments=comments + args.result_dir, filename, vq, comments=comments ) - log_string(f"wrote {filename}") - - # ------------------------ - - input = input.to(main_device) - - if nb_train_samples % args.batch_size == 0: - optimizer.zero_grad() - - output = model_gen(mygpt.BracketedSequence(input)).x - loss = F.cross_entropy(output.transpose(1, 2), input) - acc_train_loss += loss.item() * input.size(0) - nb_train_samples += input.size(0) - - loss.backward() - if nb_train_samples % args.batch_size == 0: - optimizer.step() - - train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) - - log_string(f"train_perplexity {n_epoch} model ae {train_perplexity}") + log_string( + f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in model.train_c_quiz_bags ])} test {sum([q.size(0) for q in model.test_c_quiz_bags ])}" + ) ###################################################################### @@ -1072,7 +780,8 @@ for k in range(args.nb_gpts): ).to(main_device) model.id = k - model.c_quiz_bags = [] + model.train_c_quiz_bags = [] + model.test_c_quiz_bags = [] if args.schedule_free: model.optimizer = schedulefree.AdamWScheduleFree( @@ -1087,29 +796,6 @@ for k in range(args.nb_gpts): ###################################################################### -if args.test == "quant": - nb_bits = 8 - for model in models: - model.trunk.insert( - 12, - mygpt.CacheWrapper( - mygpt.RandomBypass( - nn.Sequential( - nn.Linear(args.dim_model, nb_bits), - mygpt.BSQ(nb_bits), - nn.Linear(nb_bits, args.dim_model), - ), - 0.1, - ) - ), - ) - - print(model) - exit(0) - - -###################################################################### - current_epoch = 0 if args.resume: @@ -1170,153 +856,6 @@ if args.dirty_debug: ###################################################################### -if args.test == "tsne": - model = models[0] - - quizzes = [] - labels = [] - nb_samples_per_task = 1000 - - for n, t in enumerate(args.grids_world_tasks.split(",")): - quizzes.append( - quiz_machine.problem.generate_w_quizzes(nb_samples_per_task, [t]) - ) - labels.append(torch.full((quizzes[-1].size(0),), n)) - - quizzes = torch.cat(quizzes, dim=0) - labels = torch.cat(labels, dim=0) - - with torch.autograd.no_grad(): - model.eval().to(main_device) - record = [] - for input, targets in zip( - quizzes.split(args.batch_size), labels.split(args.batch_size) - ): - input = input.to(main_device) - bs = mygpt.BracketedSequence(input) - bs = mygpt.BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb) - bs = model.embedding(bs) - bs = model.trunk[args.nb_blocks // 2](bs) - record.append((bs.x.to("cpu"), targets)) - - x = torch.cat([x for x, y in record], dim=0).flatten(1) - y = torch.cat([y for x, y in record], dim=0) - - print(f"{x.size()=} {y.size()=}") - # torch.save((x,y), "/tmp/embed.pth") - # exit(0) - - from sklearn.manifold import TSNE - - x_np = x.numpy() - z_np = TSNE(n_components=2, perplexity=50).fit_transform(x_np) - z = torch.from_numpy(z_np) - - print(f"{z.size()=}") - - with open("/tmp/result.dat", "w") as f: - for k in range(z.size(0)): - f.write(f"{y[k]} {z[k,0]} {z[k,1]}\n") - - exit(0) - -###################################################################### - -if args.test == "generator": - token_prolog_0 = vocabulary_size + 0 - token_prolog_1 = vocabulary_size + 1 - token_prolog_2 = vocabulary_size + 2 - generator_vocabulary_size = vocabulary_size + 3 - - generator = mygpt.MyGPT( - vocabulary_size=generator_vocabulary_size, - dim_model=args.dim_model, - dim_keys=args.dim_keys, - dim_hidden=args.dim_hidden, - nb_heads=args.nb_heads, - nb_blocks=args.nb_blocks, - compute_attzero=compute_causal_attzero, - dropout=args.dropout, - ).to(main_device) - - generator.main_test_accuracy = 0.0 - - filename = f"generator.pth" - - try: - d = torch.load(os.path.join(args.result_dir, filename)) - generator.load_state_dict(d[0]) - generator.main_test_accuracy = d[1] - log_string(f"successfully loaded {filename}") - except FileNotFoundError: - log_string(f"cannot find {filename}") - pass - - for n_epoch in range(args.nb_epochs): - one_generator_epoch( - generator, - quiz_machine=quiz_machine, - models=models, - fraction_w_quizzes=1 if n_epoch < 25 else 0.5, - local_device=main_device, - ) - - filename = f"generator.pth" - torch.save( - (generator.state_dict(), generator.main_test_accuracy), - os.path.join(args.result_dir, filename), - ) - log_string(f"wrote {filename}") - - c_quizzes, prologs = generate_c_quizzes_with_generator( - generator, quiz_machine, args.batch_size - ) - - seq_logproba = quiz_machine.models_logprobas( - models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) + quiz_machine.models_logprobas( - models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) - - probas = seq_logproba.exp() - - u0 = probas <= args.proba_not_understands - u2 = probas >= args.proba_understands - u1 = (u0 | u2) == False - - predicted_prologs = ( - (u0.long() * token_prolog_0) - + (u1.long() * token_prolog_1) - + (u2.long() * token_prolog_2) - ) - - comments = [] - - nb_errors = (predicted_prologs != prologs).long().sum() - nb_total = prologs.numel() - - log_string(f"generator_error {nb_errors} / {nb_total}") - - def readable(prologs): - return (prologs == token_prolog_1) + 2 * (prologs == token_prolog_2) - - for aa, ee, ff in zip(probas, readable(predicted_prologs), readable(prologs)): - sa = "prolog " + " ".join( - [f"{e.item()}/{f.item()}" for e, f in zip(ee, ff)] - ) - sp = "proba " + " ".join([f"{p.item():.02f}" for p in aa]) - comments.append(sa + "\n" + sp) - - filename = f"generator_batch_{n_epoch:04d}.png" - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, filename, c_quizzes, comments=comments - ) - log_string(f"wrote {filename}") - - exit(0) - -###################################################################### - for n_epoch in range(current_epoch, args.nb_epochs): state = {"current_epoch": n_epoch} filename = "state.pth" @@ -1336,8 +875,8 @@ for n_epoch in range(current_epoch, args.nb_epochs): record_new_c_quizzes( models, quiz_machine, - nb_for_train=args.nb_new_c_quizzes_for_train, - nb_for_test=args.nb_new_c_quizzes_for_test, + args.nb_new_c_quizzes_for_train, + args.nb_new_c_quizzes_for_test, ) filename = "c_quizzes.pth" diff --git a/quiz_machine.py b/quiz_machine.py index b2287b8..1fe2e94 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -28,7 +28,7 @@ def one_batch_masked_inplace_autoregression( model, input, ar_mask, - acc_seq_logproba, + acc_seq_logprobas, deterministic_synthesis=False, ): if input.size(0) == 0: @@ -53,7 +53,7 @@ def one_batch_masked_inplace_autoregression( all_n = torch.arange(t_next.size(0)) - acc_seq_logproba += ar_mask[:, s] * logits.log_softmax(dim=1)[all_n, t_next] + acc_seq_logprobas += ar_mask[:, s] * logits.log_softmax(dim=1)[all_n, t_next] input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] @@ -107,7 +107,7 @@ class QuizMachine: model, input, ar_mask, - seq_logproba, + seq_logprobas, progress_bar_desc=None, ): assert input.size() == ar_mask.size() @@ -115,7 +115,7 @@ class QuizMachine: batches = zip( input.split(self.batch_size), ar_mask.split(self.batch_size), - seq_logproba.split(self.batch_size), + seq_logprobas.split(self.batch_size), ) if progress_bar_desc is not None: @@ -130,12 +130,12 @@ class QuizMachine: t = model.training model.eval() - for input, ar_mask, seq_logproba in batches: + for input, ar_mask, seq_logprobas in batches: one_batch_masked_inplace_autoregression( model=model, input=input, ar_mask=ar_mask, - acc_seq_logproba=seq_logproba, + acc_seq_logprobas=seq_logprobas, deterministic_synthesis=False, ) @@ -143,9 +143,9 @@ class QuizMachine: ###################################################################### - def data_input(self, model, nb_samples): - if len(model.c_quiz_bags) > 0: - c_quizzes = torch.cat(model.c_quiz_bags, dim=0) + def data_input(self, nb_samples, c_quiz_bags): + if len(c_quiz_bags) > 0: + c_quizzes = torch.cat(c_quiz_bags, dim=0) if c_quizzes.size(0) > nb_samples // 2: i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2] @@ -191,23 +191,23 @@ class QuizMachine: ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, mask=mask) result = quizzes * (1 - ar_mask) - seq_logproba = torch.zeros(quizzes.size(0), device=self.device) + seq_logprobas = torch.zeros(quizzes.size(0), device=self.device) self.autoregression( model=model, input=result, ar_mask=ar_mask, - seq_logproba=seq_logproba, + seq_logprobas=seq_logprobas, progress_bar_desc="accuracy", ) correct = (result == quizzes).min(dim=1).values.long() - result = result.to("cpu") - correct = correct.to("cpu") - seq_logproba = seq_logproba.to("cpu") + # result = result.to("cpu") + # correct = correct.to("cpu") + # seq_logprobas = seq_logprobas.to("cpu") - return result, correct, seq_logproba + return result, correct, seq_logprobas ###################################################################### @@ -226,6 +226,7 @@ class QuizMachine: result[i], correct[i], _ = self.predict( model=model, quizzes=input[i], struct=struct, mask=mask_generate ) + predicted_parts[i] = torch.tensor(mask_generate, device=self.device)[ None, : ] @@ -288,7 +289,7 @@ class QuizMachine: def models_logprobas( self, - models_for_validation, + model, c_quizzes, struct, mask_loss, @@ -300,9 +301,8 @@ class QuizMachine: c_quizzes = self.problem.reconfigure(c_quizzes, struct) - seq_logproba = torch.zeros( + seq_logprobas = torch.zeros( c_quizzes.size(0), - max([m.id for m in models_for_validation]) + 1, device=device, ) @@ -311,35 +311,32 @@ class QuizMachine: # c_quizzes, self.prompt_noise, struct=struct, mask=mask_noise # ) - for model in models_for_validation: - with torch.autograd.no_grad(): - t = model.training - model.eval() - - for input, l in zip( - c_quizzes.split(self.batch_size), - seq_logproba.split(self.batch_size), - ): - input = input.to(device) - quiz_mask_loss = self.make_quiz_mask( - input, struct=struct, mask=mask_loss - ) - output = model(mygpt.BracketedSequence(input)).x - l[:, model.id] = ( - -F.cross_entropy( - output.transpose(1, 2), input, reduction="none" - ) - * quiz_mask_loss - ).sum(dim=1) + with torch.autograd.no_grad(): + t = model.training + model.eval() + + for input, l in zip( + c_quizzes.split(self.batch_size), + seq_logprobas.split(self.batch_size), + ): + input = input.to(device) + quiz_mask_loss = self.make_quiz_mask( + input, struct=struct, mask=mask_loss + ) + output = model(mygpt.BracketedSequence(input)).x + l[...] = ( + -F.cross_entropy(output.transpose(1, 2), input, reduction="none") + * quiz_mask_loss + ).sum(dim=1) - model.train(t) + model.train(t) - return seq_logproba.to("cpu") + return seq_logprobas.to("cpu") ###################################################################### def generate_c_quizzes(self, nb, model_for_generation, procedure, recorder=None): - seq_logproba = torch.zeros(nb, device=self.device) + seq_logprobas = torch.zeros(nb, device=self.device) c_quizzes = None @@ -358,7 +355,7 @@ class QuizMachine: model=model_for_generation, input=c_quizzes, ar_mask=self.make_quiz_mask(c_quizzes, s, m), - seq_logproba=seq_logproba, + seq_logprobas=seq_logprobas, ) model_for_generation.reset_transformations() -- 2.39.5