def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models)
- nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate
+ nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate // 10
start_time = time.perf_counter()
if args.test == "func":
- train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
test_input = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
- L = train_input.size(1) // 4
- f_len = 25
+ L = test_input.size(1) // 4
+ f_len = 50
model = Thinker(
vocabulary_size=vocabulary_size,
dim_hidden=args.dim_hidden,
nb_heads=args.nb_heads,
nb_blocks=args.nb_blocks,
- f_len=20,
+ f_len=f_len,
dropout=args.dropout,
).to(main_device)
for n_epoch in range(args.nb_epochs):
model.train()
+ train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
+
nb_train_samples, acc_train_loss = 0, 0.0
for input in tqdm.tqdm(