3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import threading, queue, torch, tqdm
12 def nb_token_values(self):
15 def trivial_prompts_and_answers(self, prompts, answers):
18 # returns two tensors nb x D and nb x D'
19 def generate_prompts_and_answers(self, nb):
22 # save a file to visualize quizzes; you can save a txt or png file
29 predicted_prompts=None,
30 predicted_answers=None,
35 class MultiThreadProblem:
def __init__(self, problem, max_nb_cached_chunks, chunk_size, nb_threads=1):
    """Wrap ``problem`` and start background threads that pre-fill a chunk queue.

    Args:
        problem: the object every other method of this wrapper delegates to.
        max_nb_cached_chunks: ``maxsize`` of the internal ``queue.Queue``;
            producers block once that many chunks are waiting.
        chunk_size: stored for the worker threads (presumably the number of
            samples requested per generated chunk -- confirm in fill_cache).
        nb_threads: number of daemon threads running ``self.fill_cache``.
    """
    self.problem = problem
    self.chunk_size = chunk_size
    self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
    # Surplus rows kept between calls to generate_prompts_and_answers. It is
    # read there before ever being written, so it has to exist from the start,
    # and all state is set up before any worker thread begins running.
    self.rest = None
    for _ in range(nb_threads):
        # Daemon threads: they loop forever and must not block interpreter exit.
        threading.Thread(target=self.fill_cache, daemon=True).start()
def nb_token_values(self):
    """Return whatever the wrapped problem reports as its number of token values."""
    wrapped = self.problem
    return wrapped.nb_token_values()
53 predicted_prompts=None,
54 predicted_answers=None,
56 self.problem.save_quizzes(
61 predicted_prompts=None,
62 predicted_answers=None,
67 prompts, answers = self.problem.generate_prompts_and_answers(
71 self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
def trivial_prompts_and_answers(self, prompts, answers):
    """Forward ``prompts`` and ``answers`` to the wrapped problem and return its result."""
    delegate = self.problem.trivial_prompts_and_answers
    return delegate(prompts, answers)
def generate_prompts_and_answers(self, nb):
    """Return exactly ``nb`` prompt/answer rows taken from the chunk queue.

    Chunks are pulled from ``self.queue`` until at least ``nb`` rows are
    available. They are concatenated along dim 0, and any surplus rows are
    stored in ``self.rest`` so the next call consumes them first.

    Args:
        nb: number of rows to return.

    Returns:
        A ``(prompts, answers)`` pair of tensors, each with ``nb`` rows.

    NOTE(review): the listing this was reviewed from elides several lines of
    this method (the ``else``, the tqdm arguments other than ``desc``, the
    loop header and the computation of ``k``); they are reconstructed here --
    confirm against the full file.
    """
    # Leftovers from the previous call are a pair of tensors; wrap each one in
    # a list so it is handled as the first chunk. getattr keeps this method
    # safe even if the constructor never initialised the attribute.
    leftover = getattr(self, "rest", None)
    if leftover is not None:
        prompts, answers = [leftover[0]], [leftover[1]]
    else:
        prompts, answers = [], []

    self.rest = None

    n = sum(p.size(0) for p in prompts)

    with tqdm.tqdm(
        total=nb,
        dynamic_ncols=True,
        desc="world generation",
    ) as pbar:
        # Always hold at least one chunk so torch.cat below has something to
        # concatenate, even when nb <= 0.
        while n < nb or not prompts:
            p, s = self.queue.get(block=True)
            prompts.append(p)
            answers.append(s)
            n += p.size(0)
            pbar.update(p.size(0))

    prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
    assert n == prompts.size(0)

    # Number of surplus rows beyond what was asked for.
    k = n - nb

    if k > 0:
        # Keep the surplus on the instance (not in a throwaway local) so the
        # next call starts from it instead of discarding generated samples.
        self.rest = (prompts[-k:], answers[-k:])
        prompts, answers = prompts[:-k], answers[:-k]

    return prompts, answers