+
+
class MultiThreadProblem:
    """Wrap a quiz problem and pre-generate prompt/answer chunks in a
    background daemon thread, caching them in a bounded queue so that
    consumers rarely wait for on-demand generation.

    All non-generation methods simply delegate to the wrapped problem.
    """

    def __init__(self, problem, max_nb_cached_chunks, chunk_size):
        """Start the producer thread.

        Args:
            problem: the wrapped problem; must provide
                generate_prompts_and_answers / nb_token_values /
                save_quizzes / trivial_prompts_and_answers.
            max_nb_cached_chunks: queue capacity — the producer blocks
                once this many chunks are cached.
            chunk_size: number of rows generated per chunk.
        """
        self.problem = problem
        self.chunk_size = chunk_size
        # Bounded queue: the producer thread blocks on put() when full,
        # so at most max_nb_cached_chunks chunks are held in memory.
        self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
        # Daemon thread so the process can exit without joining it.
        threading.Thread(target=self.fill_cache, daemon=True).start()
        # Surplus (prompts, answers) rows left over from a previous
        # over-fetch, stored as one-element lists of tensors.
        self.rest = None

    def nb_token_values(self):
        """Delegate: number of distinct token values of the problem."""
        return self.problem.nb_token_values()

    def save_quizzes(
        self,
        result_dir,
        filename_prefix,
        prompts,
        answers,
        predicted_prompts=None,
        predicted_answers=None,
    ):
        """Delegate quiz saving to the wrapped problem.

        BUG FIX: the original hard-coded predicted_prompts=None and
        predicted_answers=None in the delegated call, silently dropping
        any predictions the caller supplied. They are now forwarded.
        """
        self.problem.save_quizzes(
            result_dir,
            filename_prefix,
            prompts,
            answers,
            predicted_prompts=predicted_prompts,
            predicted_answers=predicted_answers,
        )

    def fill_cache(self):
        """Producer loop (runs in the daemon thread): generate chunks
        forever, blocking on the queue whenever the cache is full."""
        while True:
            prompts, answers = self.problem.generate_prompts_and_answers(
                self.chunk_size
            )

            self.queue.put((prompts, answers), block=True)

    def trivial_prompts_and_answers(self, prompts, answers):
        """Delegate the triviality test to the wrapped problem."""
        return self.problem.trivial_prompts_and_answers(prompts, answers)

    def generate_prompts_and_answers(self, nb):
        """Return exactly `nb` cached (prompts, answers) rows.

        Whole chunks are consumed from the queue until at least `nb`
        rows are available; any surplus rows are stashed in self.rest
        and served first on the next call.

        BUG FIX: the original read the bare name `rest` (a NameError as
        soon as self.rest was non-None) and assigned the surplus to a
        *local* `rest`, so leftover rows were silently discarded. The
        surplus is now persisted on self.rest, as one-element lists so
        the size-summing/append logic below works on the next call.
        """
        if self.rest is not None:
            prompts, answers = self.rest
        else:
            prompts, answers = [], []

        self.rest = None

        # Rows accumulated so far (each element is a 2-D tensor chunk).
        n = sum(p.size(0) for p in prompts)

        while n < nb:
            p, a = self.queue.get(block=True)
            prompts.append(p)
            answers.append(a)
            n += p.size(0)

        prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)

        k = n - nb

        if k > 0:
            # Keep the surplus for the next call; wrap in lists to match
            # the accumulation format expected at the top of this method.
            self.rest = ([prompts[-k:]], [answers[-k:]])
            prompts, answers = prompts[:-k], answers[:-k]

        return prompts, answers