# 0 -> proba <= proba_not_understands, 2 -> proba >= proba_understands, and 1 otherwise.
-def generate_c_quizz_with_generator(generator, quiz_machine, nb):
+def generate_c_quizzes_with_generator(generator, quiz_machine, nb):
generator.to(main_device)
struct = ("A", "f_A", "B", "f_B")
num_classes=args.nb_gpts,
)
- prolog_c_quizzes = token_prolog_0 * i + token_prolog_2 * (1 - i)
- prolog_ar_mask = ar_mask.new_zeros(ar_mask.size(0), prolog_c_quizzes.size(1))
+ prologs_c_quizzes = token_prolog_0 * i + token_prolog_2 * (1 - i)
+ prologs_ar_mask = ar_mask.new_zeros(ar_mask.size(0), prologs_c_quizzes.size(1))
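+ # (assuming i is a 0/1 tensor drawn per quiz) i == 1 requests
+ # token_prolog_0 (not understood), i == 0 requests token_prolog_2
+ # (understood)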
- prologued_c_quizzes = torch.cat([prolog_c_quizzes, c_quizzes], dim=1).to(
+ prologued_c_quizzes = torch.cat([prologs_c_quizzes, c_quizzes], dim=1).to(
main_device
)
- prologued_ar_mask = torch.cat([prolog_ar_mask, ar_mask], dim=1).to(main_device)
+ prologued_ar_mask = torch.cat([prologs_ar_mask, ar_mask], dim=1).to(main_device)
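+ # the prologs are prepended as pure conditioning: their ar_mask
+ # entries are zero, so autoregressive sampling only fills in the quiz
+ # tokens that follow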
seq_logproba = torch.zeros(
prologued_c_quizzes.size(0), device=prologued_c_quizzes.device
prologued_c_quizzes * (prologued_c_quizzes < vocabulary_size).long()
)
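+ # tokens at or above vocabulary_size are the prolog tokens; the
+ # masking above presumably keeps whatever consumes this tensor within
+ # the regular vocabulary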
- return prologued_c_quizzes[:, prolog_c_quizzes.size(1) :].to("cpu")
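+ # return both the quizzes and the prologs they were conditioned on
+ # (presumably one prolog column per GPT) so callers can check whether
+ # the requested labels were achieved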
+ c_quizzes = prologued_c_quizzes[:, prologs_c_quizzes.size(1) :]
+
+ return c_quizzes.to("cpu"), prologs_c_quizzes.to("cpu")
def batches_for_generator(generator, quiz_machine, models, fraction_w_quizzes=1.0):
)
else:
# Or we use the generator itself to generate them
- c_quizzes = generate_c_quizz_with_generator(
+ c_quizzes, _ = generate_c_quizzes_with_generator(
generator, quiz_machine, args.batch_size
)
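+ # the prologs the quizzes were conditioned on are not needed here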
u2 = probas >= args.proba_understands
u1 = (u0 | u2) == False
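+ # three disjoint bands per model: u0 = not understood, u2 = understood,
+ # u1 = everything in between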
- prolog = (
+ prologs = (
(u0.long() * token_prolog_0)
+ (u1.long() * token_prolog_1)
+ (u2.long() * token_prolog_2)
)
- prologued_c_quizzes = torch.cat([prolog, c_quizzes], dim=1)
+ prologued_c_quizzes = torch.cat([prologs, c_quizzes], dim=1)
# nb_u2 = u2.long().sum(dim=1)
# nb_u0 = u0.long().sum(dim=1)
######################################################################
if args.test_generator:
- filename = f"generator.pth"
-
- try:
- d = torch.load(os.path.join(args.result_dir, filename))
- generator.load_state_dict(d[0])
- generator.main_test_accuracy = d[1]
- log_string(f"successfully loaded {filename}")
- except FileNotFoundError:
- log_string(f"cannot find {filename}")
- pass
-
token_prolog_0 = vocabulary_size + 0
token_prolog_1 = vocabulary_size + 1
token_prolog_2 = vocabulary_size + 2
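+ # three extra tokens on top of the regular vocabulary, encoding the
+ # prolog labels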
generator.main_test_accuracy = 0.0
+ filename = "generator.pth"
+
+ try:
+ d = torch.load(os.path.join(args.result_dir, filename))
+ generator.load_state_dict(d[0])
+ generator.main_test_accuracy = d[1]
+ log_string(f"successfully loaded {filename}")
+ except FileNotFoundError:
+ log_string(f"cannot find {filename}")
+
for n_epoch in range(args.nb_epochs):
one_generator_epoch(
generator,
)
log_string(f"wrote {filename}")
- c_quizzes = generate_c_quizz_with_generator(
+ c_quizzes, prologs = generate_c_quizzes_with_generator(
generator, quiz_machine, args.batch_size
)
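+ # keep the prologs the generation was conditioned on so we can check
+ # below whether the models' verdicts match them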
models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
)
- print(seq_logproba.exp())
+ probas = seq_logproba.exp()
+
+ u0 = probas <= args.proba_not_understands
+ u2 = probas >= args.proba_understands
+ u1 = ~(u0 | u2)
+
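+ # re-derive prolog tokens from the measured probabilities, to compare
+ # with the prologs the generator was asked to realize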
+ predicted_prologs = (
+ (u0.long() * token_prolog_0)
+ + (u1.long() * token_prolog_1)
+ + (u2.long() * token_prolog_2)
+ )
comments = []
- for l in seq_logproba:
- comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l]))
+ nb_errors = (predicted_prologs != prologs).long().sum().item()
+ nb_total = prologs.numel()
+
+ log_string(f"generator_error {nb_errors} / {nb_total}")
+
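+ # map token_prolog_{0,1,2} back to plain 0/1/2 for compact logging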
+ def readable(prologs):
+ return (prologs == token_prolog_1) + 2 * (prologs == token_prolog_2)
+
+ for aa, ee, ff in zip(probas, readable(predicted_prologs), readable(prologs)):
+ sa = "prolog " + " ".join(
+ [f"{e.item()}/{f.item()}" for e, f in zip(ee, ff)]
+ )
+ sp = "proba " + " ".join([f"{p.item():.02f}" for p in aa])
+ comments.append(sa + "\n" + sp)
filename = f"generator_batch_{n_epoch:04d}.png"
quiz_machine.problem.save_quizzes_as_image(