3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import threading, queue, torch, tqdm
def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1):
    """Set up the optional background cache of pre-generated quizzes.

    With ``nb_threads > 0`` a queue of chunks is created and ``nb_threads``
    daemon threads are started, each running ``self.fill_cache``. Otherwise
    caching is disabled and ``self.queue`` is ``None``.

    Args:
        max_nb_cached_chunks: maximum number of chunks kept in the queue;
            ``None`` (or any value <= 0) means unbounded.
        chunk_size: number of quizzes generated per cached chunk; required
            and positive when caching is enabled.
        nb_threads: number of background generator threads; <= 0 disables
            the cache.

    Raises:
        ValueError: caching is enabled but ``chunk_size`` is not positive.
    """
    self.chunk_size = chunk_size
    # Surplus quizzes carried over between generate_prompts_and_answers calls.
    self.rest = None

    if nb_threads > 0:
        if chunk_size is None or chunk_size <= 0:
            raise ValueError("chunk_size must be positive when nb_threads > 0")
        # queue.Queue treats maxsize <= 0 as "infinite" but compares it with
        # `> 0` in put()/full(), so None must be mapped to 0 explicitly.
        self.queue = queue.Queue(maxsize=max_nb_cached_chunks or 0)
        # Attributes above are set before the workers start reading them.
        for _ in range(nb_threads):
            threading.Thread(target=self.fill_cache, daemon=True).start()
    else:
        self.queue = None
def nb_cached_quizzes(self):
    """Return how many pre-generated quizzes are waiting in the cache.

    Every queued item is one chunk of ``self.chunk_size`` quizzes, so the
    count is ``qsize() * chunk_size``. ``qsize()`` is only a snapshot while
    worker threads are running. Returns ``None`` when caching is disabled.
    """
    if self.queue is None:
        # NOTE(review): the original value for this branch is not visible in
        # this excerpt; None is assumed to mean "no cache" — confirm callers.
        return None
    return self.queue.qsize() * self.chunk_size
def nb_token_values(self):
    """Return the number of distinct token values a quiz may contain.

    Abstract hook: each concrete problem defines its own token set and must
    override this. Raises NotImplementedError rather than silently
    returning None.
    """
    raise NotImplementedError(f"{type(self).__name__} must implement nb_token_values")
def trivial_prompts_and_answers(self, prompts, answers):
    """Problem-specific check on a batch of ``(prompts, answers)``.

    NOTE(review): the base implementation is not visible in this excerpt;
    by its name this presumably identifies trivially solvable quizzes —
    confirm the expected return value with the concrete problems.
    Abstract hook: raises NotImplementedError unless overridden.
    """
    raise NotImplementedError(
        f"{type(self).__name__} must implement trivial_prompts_and_answers"
    )
# The one to implement: returns two tensors, nb x D prompts and nb x D' answers.
def generate_prompts_and_answers_(self, nb):
    """Generate ``nb`` fresh quizzes; every concrete problem must override this.

    It is the only generator the cache worker threads and the uncached path
    call, so a silent ``None`` here would only fail later and less clearly.

    Returns:
        ``(prompts, answers)``: tensors of shape ``nb x D`` and ``nb x D'``.
    """
    raise NotImplementedError(
        f"{type(self).__name__} must implement generate_prompts_and_answers_"
    )
38 # Save a file to visualize quizzes; either a txt or a png file will do.
# NOTE(review): this method's name and leading parameters are outside this
# excerpt. The two optional parameters below presumably carry model
# predictions to render alongside the quizzes (None when there are none) —
# confirm against the full signature.
45 predicted_prompts=None,
46 predicted_answers=None,
def fill_cache(self):
    """Worker loop for the cache threads: keep the queue topped up forever.

    Each iteration generates ``self.chunk_size`` quizzes and enqueues them
    as one ``(prompts, answers)`` chunk. ``put(block=True)`` makes the
    thread wait while a bounded queue is full. The loop has no exit; the
    threads are started as daemons, so they do not keep the process alive.
    """
    while True:
        prompts, answers = self.generate_prompts_and_answers_(self.chunk_size)
        # Presumably moved to CPU so queued chunks do not pin device memory.
        self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
def generate_prompts_and_answers(self, nb):
    """Return ``nb`` quizzes as a ``(prompts, answers)`` pair of tensors.

    Without a cache (``self.queue is None``) the quizzes are produced on the
    spot by ``generate_prompts_and_answers_``. With a cache, the quizzes left
    over from the previous call (``self.rest``) are used first. Then chunks
    are taken from the queue, blocking until worker threads supply them,
    until at least ``nb`` quizzes are gathered. The first ``nb`` are
    returned and any surplus is stored in ``self.rest`` for the next call.

    Args:
        nb: number of quizzes to return.

    Returns:
        ``(prompts, answers)``; on the cached path both are concatenated
        along dim 0 and have exactly ``nb`` rows.
    """
    if self.queue is None:
        return self.generate_prompts_and_answers_(nb)

    # The surplus lives on ``self`` (not a local) so it survives across
    # calls. It is a tensor pair, so wrap each tensor in a list before the
    # chunks pulled from the queue are appended.
    if self.rest is not None:
        prompts, answers = [self.rest[0]], [self.rest[1]]
    else:
        prompts, answers = [], []
    self.rest = None

    n = sum(p.size(0) for p in prompts)

    with tqdm.tqdm(total=nb, desc="world generation") as pbar:
        while n < nb:
            p, s = self.queue.get(block=True)
            prompts.append(p)
            answers.append(s)
            n += p.size(0)
            pbar.update(p.size(0))

    prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
    assert n == prompts.size(0)

    # Keep the k extra quizzes for the next call. The k > 0 guard matters:
    # prompts[:-0] would be empty.
    k = n - nb
    if k > 0:
        self.rest = (prompts[-k:], answers[-k:])
        prompts, answers = prompts[:-k], answers[:-k]

    return prompts, answers