# Score every candidate quiz under every model, filling one column of
# `logproba` per model.  (Cleaned up: this block had unified-diff markers
# pasted into it; the stale `logits.split(...)` line is removed and the
# `+` replacement using `logproba` is kept.)
if c_quizzes.size(0) > 0:
    # One row per quiz, one column per model.  Tensor.new(...) allocates an
    # *uninitialized* tensor with c_quizzes's dtype and device.
    # NOTE(review): if c_quizzes holds integer token ids, this dtype will
    # truncate the float losses stored below — confirm a float dtype is wanted.
    logproba = c_quizzes.new(c_quizzes.size(0), len(models))

    # Walk the quizzes and the matching slice of the score table in lockstep,
    # one mini-batch at a time.
    for q, lp in zip(
        c_quizzes.split(args.batch_size), logproba.split(args.batch_size)
    ):
        for model in models:
            # NOTE(review): F.cross_entropy requires a target argument
            # (input, target); as written this call raises TypeError —
            # confirm the intended target (presumably q or a shifted copy).
            # NOTE(review): lp[model.id] selects a *row* of this chunk, not
            # the model's column — lp[:, model.id] may be intended; verify
            # against how logproba is consumed downstream.
            lp[model.id] = F.cross_entropy(model(q))