+ generator.train(t)
+
+ generator.reset_transformations()
+
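+ # Zero out any token at or above vocabulary_size, so only valid quiz
+ # tokens remain in the generated sequences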
+ prologued_c_quizzes = (
+ prologued_c_quizzes * (prologued_c_quizzes < vocabulary_size).long()
+ )
+
+ c_quizzes = prologued_c_quizzes[:, prologs_c_quizzes.size(1) :]
+
+ return c_quizzes.to("cpu"), prologs_c_quizzes.to("cpu")
+
+
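+ # Yields training batches for the generator: world quizzes (or quizzes
+ # produced by the generator itself), filtered for triviality, each
+ # prefixed with a prolog encoding how well the given models solve it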
+def batches_for_generator(generator, quiz_machine, models, fraction_w_quizzes=1.0):
+ samples = []
+
+ for _ in range(args.nb_train_samples // args.batch_size):
+ while sum([x.size(0) for x in samples]) < args.batch_size:
+ # Generate a bunch of quizzes
+
+ if torch.rand(1).item() <= fraction_w_quizzes:
+ # Either we start with the world quizzes
+ c_quizzes = quiz_machine.problem.generate_w_quizzes(
+ args.batch_size, progress_bar=False
+ )
+ else:
+ # Or we use the generator itself to generate them
+ c_quizzes, _ = generate_c_quizzes_with_generator(
+ generator, quiz_machine, args.batch_size
+ )
+
+ # We remove the trivial ones
+ to_keep = quiz_machine.problem.trivial(c_quizzes) == False
+ c_quizzes = c_quizzes[to_keep]
+
+ # If there are remaining ones, we compute the true prolog
+ # that indicates how the GPTs solve it
+
+ if c_quizzes.size(0) > 0:
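+ # Log-probability of each model solving the quiz, summed over the
+ # two orderings of the quiz structure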
+ seq_logproba = quiz_machine.models_logprobas(
+ models,
+ c_quizzes,
+ ("A", "f_A", "B", "f_B"),
+ (0, 0, 0, 1),
+ (0, 0, 1, 0),
+ ) + quiz_machine.models_logprobas(
+ models,
+ c_quizzes,
+ ("f_A", "A", "f_B", "B"),
+ (0, 0, 0, 1),
+ (0, 0, 1, 0),
+ )
+
+ probas = seq_logproba.exp()
+
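+ # Partition the quizzes into three bands: not understood (proba at or
+ # below proba_not_understands), understood (at or above
+ # proba_understands), and in-between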
+ u0 = probas <= args.proba_not_understands
+ u2 = probas >= args.proba_understands
+ u1 = (u0 | u2) == False
+
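+ # Build the prolog: one token per model, indicating which of the
+ # three bands that model falls into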
+ prologs = (
+ (u0.long() * token_prolog_0)
+ + (u1.long() * token_prolog_1)
+ + (u2.long() * token_prolog_2)
+ )
+
+ prologued_c_quizzes = torch.cat([prologs, c_quizzes], dim=1)
+
+ # nb_u2 = u2.long().sum(dim=1)
+ # nb_u0 = u0.long().sum(dim=1)
+ # prologued_c_quizzes = prologued_c_quizzes[(nb_u2 >= 1) & (nb_u0 >= 1)]
+
+ if prologued_c_quizzes.size(0) > 0:
+ samples.append(prologued_c_quizzes)
+
+ # Now we yield a batch
+
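+ # Samples beyond batch_size are kept for the next iteration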
+ x = torch.cat(samples, dim=0)
+ samples = [x[args.batch_size :]]
+
+ yield x[: args.batch_size]
+
+
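+ # One training epoch for the generator: standard next-token
+ # cross-entropy on batches of prologued c-quizzes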
+def one_generator_epoch(
+ generator, quiz_machine, models, fraction_w_quizzes, local_device=main_device
+):
+ generator.to(local_device).train()
+
+ optimizer = torch.optim.Adam(generator.parameters(), lr=args.learning_rate)
+
+ nb_train_samples, acc_train_loss = 0, 0.0
+
+ src = batches_for_generator(
+ generator=generator,
+ quiz_machine=quiz_machine,
+ models=models,
+ fraction_w_quizzes=fraction_w_quizzes,
+ )
+
+ for input in tqdm.tqdm(
+ src,
+ dynamic_ncols=True,
+ desc="training",
+ total=args.nb_train_samples // args.batch_size,
+ ):
+ input = input.to(local_device)
+
+ if nb_train_samples % args.batch_size == 0:
+ optimizer.zero_grad()
+
+ targets = input
+
+ output = generator(mygpt.BracketedSequence(input)).x
+ loss = F.cross_entropy(output.transpose(1, 2), targets)
+ acc_train_loss += loss.item() * input.size(0)
+ nb_train_samples += input.size(0)
+
+ loss.backward()
+
+ if nb_train_samples % args.batch_size == 0:
+ optimizer.step()
+
+ train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+ log_string(f"train_perplexity {n_epoch} generator - {train_perplexity}")
+
+ generator.to(main_device)
+
+
+######################################################################
+
+
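+ # Trains model_gen to produce c-quizzes that model_pred1 solves with
+ # high confidence and model_pred2 does not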
+def train_complexifier(model_gen, model_pred1, model_pred2):
+ samples = []
+ perf = []
+
+ optimizer = torch.optim.Adam(model_gen.parameters(), lr=args.learning_rate)
+
+ nb_train_samples, acc_train_loss = 0, 0.0
+
+ for n_epoch in range(args.nb_epochs):
+ for b in range(args.nb_train_samples // args.batch_size):
+ while sum([x.size(0) for x in samples]) < args.batch_size:
+ c_quizzes = quiz_machine.generate_c_quizzes(
+ args.inference_batch_size,
+ model_for_generation=model_gen,
+ procedure=c_quizzes_procedure,
+ )
+ to_keep = quiz_machine.problem.trivial(c_quizzes) == False
+ c_quizzes = c_quizzes[to_keep]
+ if c_quizzes.size(0) > 0:
+ seq_logproba = quiz_machine.models_logprobas(
+ [model_pred1, model_pred2],
+ c_quizzes,
+ ("A", "f_A", "B", "f_B"),
+ (0, 0, 0, 1),
+ ) + quiz_machine.models_logprobas(
+ [model_pred1, model_pred2],
+ c_quizzes,
+ ("f_A", "A", "f_B", "B"),
+ (0, 0, 0, 1),
+ )
+ probas = seq_logproba.exp()
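+ # Keep only the quizzes that model_pred1 understands and
+ # model_pred2 does not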
+ to_keep = (probas[:, model_pred1.id] >= args.proba_understands) & (
+ probas[:, model_pred2.id] <= args.proba_not_understands
+ )
+ log_string(
+ f"generating {to_keep.long().sum()} / {c_quizzes.size(0)}"
+ )
+ c_quizzes = c_quizzes[to_keep]
+ if c_quizzes.size(0):
+ samples.append(c_quizzes)
+
+ log_string(f"full batch {sum([x.size(0) for x in samples])}")
+
+ x = torch.cat(samples, dim=0)
+
+ input = x[: args.batch_size]
+ samples = [x[args.batch_size :]]
+
+ # -------------------
+
+ seq_logproba = quiz_machine.models_logprobas(
+ [model_pred1, model_pred2],
+ input,
+ ("A", "f_A", "B", "f_B"),
+ (0, 0, 0, 1),
+ ) + quiz_machine.models_logprobas(
+ [model_pred1, model_pred2],
+ input,
+ ("f_A", "A", "f_B", "B"),
+ (0, 0, 0, 1),
+ )
+
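+ # Annotate each saved quiz with the two predictors' solve probabilities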
+ comments = []
+
+ for l in seq_logproba:
+ comments.append(
+ f"proba {l[model_pred1.id].exp().item():.02f} {l[model_pred2.id].exp().item():.02f}"
+ )
+
+ filename = f"batch_{n_epoch:04d}_{b:04d}.png"
+ quiz_machine.problem.save_quizzes_as_image(
+ args.result_dir, filename, input, comments=comments
+ )
+ log_string(f"wrote {filename}")
+
+ # ------------------------
+
+ input = input.to(main_device)
+
+ if nb_train_samples % args.batch_size == 0:
+ optimizer.zero_grad()
+
+ output = model_gen(mygpt.BracketedSequence(input)).x
+ loss = F.cross_entropy(output.transpose(1, 2), input)
+ acc_train_loss += loss.item() * input.size(0)
+ nb_train_samples += input.size(0)
+
+ loss.backward()
+
+ if nb_train_samples % args.batch_size == 0:
+ optimizer.step()
+
+ train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+ log_string(f"train_perplexity {n_epoch} model ae {train_perplexity}")