def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device):
# train
- one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True)
- one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=False)
+ one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=True)
+ one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=False)
# predict
quizzes = quiz_machine.quiz_set(150, c_quizzes, args.c_quiz_multiplier)
input, targets, mask = batch_prediction(quizzes.to(local_device))
- result = predict(model, input, targets, mask).to("cpu")
+ result = predict(model, input, targets, mask, local_device=local_device).to("cpu")
mask = mask.to("cpu")
correct = (quizzes == result).min(dim=1).values.long()
correct_parts = (2 * correct - 1)[:, None] * mask.reshape(mask.size(0), 4, -1)[
# generate
- result = generate(model, 25).to("cpu")
+ result = generate(model, 25, local_device=local_device).to("cpu")
quiz_machine.problem.save_quizzes_as_image(
args.result_dir,
f"culture_generation_{n_epoch}_{model.id}.png",