def generate_c_quizz_with_generator(generator, quiz_machine, nb):
    """Generate `nb` culture quizzes with `generator` and return them on CPU.

    For each quiz a random one-hot "prolog" row is drawn over the nb_gpts
    slots: the hot slot gets token_prolog_0 and the others token_prolog_2.
    The prolog is prepended to empty ("A", "f_A", "B", "f_B") quizzes, the
    quiz positions (per the autoregressive mask) are filled in-place by the
    generator, tokens outside the quiz vocabulary are zeroed, and the prolog
    columns are stripped before returning.

    NOTE(review): seq_logproba is filled by the autoregression but discarded
    here, while a caller (batches_for_generator) appears to read it — confirm
    whether it should also be returned.
    """
    generator.to(main_device)

    struct = ("A", "f_A", "B", "f_B")

    c_quizzes = quiz_machine.problem.create_empty_quizzes(nb, struct=struct)
    # (1, 1, 1, 1): every quadrant of the quiz is generated autoregressively.
    ar_mask = quiz_machine.make_ar_mask(c_quizzes, struct, (1, 1, 1, 1))

    # One-hot choice of which GPT slot this quiz is "for".
    i = F.one_hot(
        torch.randint(args.nb_gpts, (c_quizzes.size(0),)),
        num_classes=args.nb_gpts,
    )

    # Hot slot -> token_prolog_0, all other slots -> token_prolog_2.
    prolog_c_quizzes = token_prolog_0 * i + token_prolog_2 * (1 - i)
    # Prolog positions are never generated, hence an all-zero mask there.
    prolog_ar_mask = ar_mask.new_zeros(ar_mask.size(0), prolog_c_quizzes.size(1))

    prologued_c_quizzes = torch.cat([prolog_c_quizzes, c_quizzes], dim=1).to(
        main_device
    )
    prologued_ar_mask = torch.cat([prolog_ar_mask, ar_mask], dim=1).to(main_device)

    # Per-sequence log-probability accumulator for the autoregression.
    seq_logproba = torch.zeros(
        prologued_c_quizzes.size(0), device=prologued_c_quizzes.device
    )

    generator.temperature = args.temperature_hot

    with torch.autograd.no_grad():
        t = generator.training
        generator.eval()

        one_batch_masked_inplace_autoregression(
            generator,
            prologued_c_quizzes,
            prologued_ar_mask,
            seq_logproba,
            deterministic_synthesis=False,
        )

        generator.train(t)
        # assumed to run inside the no_grad block, right after restoring
        # training mode — TODO confirm placement against the original file
        generator.reset_transformations()

    # Zero out any generated token outside the quiz vocabulary.
    prologued_c_quizzes = (
        prologued_c_quizzes * (prologued_c_quizzes < vocabulary_size).long()
    )

    # Strip the prolog columns; hand the result back on CPU.
    return prologued_c_quizzes[:, prolog_c_quizzes.size(1) :].to("cpu")
# NOTE(review): this span is unified-diff residue ('-' = removed line,
# '+' = added line, others = context), not runnable Python, and the original
# indentation was stripped by extraction. The hunks do not show the whole
# function: the closing parenthesis of the generate_w_quizzes(...) call and
# the final yield of a batch are missing. Only the visible changes are
# annotated here.
# CHANGED: boolean switch w_quizzes replaced by a probability; each batch
# now draws world quizzes with probability fraction_w_quizzes.
-def batches_for_generator(generator, quiz_machine, models, w_quizzes=True):
+def batches_for_generator(generator, quiz_machine, models, fraction_w_quizzes=1.0):
samples = []
for _ in range(args.nb_train_samples // args.batch_size):
while sum([x.size(0) for x in samples]) < args.batch_size:
# Generate a bunch of quizzes
- if w_quizzes:
+ if torch.rand(1).item() <= fraction_w_quizzes:
# Either we start with the world quizzes
c_quizzes = quiz_machine.problem.generate_w_quizzes(
args.batch_size, progress_bar=False
else:
# Or we use the generator itself to generate them
# FIXED: argument order corrected to match the callee's signature,
# generate_c_quizz_with_generator(generator, quiz_machine, nb).
c_quizzes = generate_c_quizz_with_generator(
- args.batch_size, generator, quiz_machine
+ generator, quiz_machine, args.batch_size
)
# We remove the trivial ones
# NOTE(review): seq_logproba is not defined anywhere in the visible part
# of this function — presumably it should come back from the quiz
# generation call above; verify against the full file before applying.
probas = seq_logproba.exp()
# CHANGED: nu/u renamed to u0/u2, with an explicit middle band u1
# (neither below the not-understands nor above the understands threshold).
- nu = probas <= args.proba_not_understands
- u = probas >= args.proba_understands
+ u0 = probas <= args.proba_not_understands
+ u2 = probas >= args.proba_understands
+ u1 = (u0 | u2) == False
# Prolog token encodes which confidence band each entry falls in.
prolog = (
- (nu.long() * token_prolog_0)
- + (((nu == False) & (u == False)).long() * token_prolog_1)
- + (u.long() * token_prolog_2)
+ (u0.long() * token_prolog_0)
+ + (u1.long() * token_prolog_1)
+ + (u2.long() * token_prolog_2)
)
prologued_c_quizzes = torch.cat([prolog, c_quizzes], dim=1)
# CHANGED: old commented-out filter rewritten in the new u0/u2 naming
# (still disabled).
- # nb_u = u.long().sum(dim=1)
- # nb_nu = nu.long().sum(dim=1)
-
- # prologued_c_quizzes = prologued_c_quizzes[
- # (nb_u + nb_nu == args.nb_gpts)
- # & (nb_nu >= 1)
- # & (nb_nu <= args.max_fail_to_validate)
- # ]
+ # nb_u2 = u2.long().sum(dim=1)
+ # nb_u0 = u0.long().sum(dim=1)
+ # prologued_c_quizzes = prologued_c_quizzes[(nb_u2 >= 1) & (nb_u0 >= 1)]
# CHANGED: only accumulate non-empty batches (guards against 0-row tensors
# after filtering).
- samples.append(prologued_c_quizzes)
+ if prologued_c_quizzes.size(0) > 0:
+ samples.append(prologued_c_quizzes)
# Now we yield a batch
# NOTE(review): diff residue continues here and this region is garbled —
# lines between hunks are missing, so the function body below is incomplete.
# CHANGED: signature takes fraction_w_quizzes (no default) instead of the
# boolean w_quizzes=True.
def one_generator_epoch(
- generator, quiz_machine, models, w_quizzes=True, local_device=main_device
+ generator, quiz_machine, models, fraction_w_quizzes, local_device=main_device
):
# NOTE(review): 'model' is likely a typo for 'generator' — no name 'model'
# is visible in this function's scope; verify against the full file.
model.to(local_device).train()
nb_train_samples, acc_train_loss = 0, 0.0
# CHANGED: hard_w_quizzes bookkeeping removed.
- hard_w_quizzes = []
-
# CHANGED: the new fraction_w_quizzes parameter is forwarded to the batch
# source.
src = batches_for_generator(
- generator=generator, quiz_machine=quiz_machine, models=models
+ generator=generator,
+ quiz_machine=quiz_machine,
+ models=models,
+ fraction_w_quizzes=fraction_w_quizzes,
)
# NOTE(review): the lines below fuse two different call sites — the
# training loop's 'for input in tqdm.tqdm(' header and the keyword
# arguments of a later top-level call to one_generator_epoch(...). The
# intervening lines were lost; do not treat this as one expression.
for input in tqdm.tqdm(
generator,
quiz_machine=quiz_machine,
models=models,
# CHANGED: world-quiz fraction is now scheduled — all world quizzes for
# the first 25 epochs, then a 50/50 mix (presumably n_epoch is the
# enclosing loop variable; confirm in the full file).
- w_quizzes=True,
+ fraction_w_quizzes=1 if n_epoch < 25 else 0.5,
local_device=main_device,
)
)
log_string(f"wrote (unknown)")
# CHANGED: the second, generator-only epoch call (w_quizzes=False) is
# removed — the scheduled fraction above replaces the two-phase scheme.
- one_generator_epoch(
- generator,
- quiz_machine=quiz_machine,
- models=models,
- w_quizzes=False,
- local_device=main_device,
- )
-
exit(0)
######################################################################