3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import threading, queue, torch, tqdm
# Set up the chunk cache: remember the chunk size, create the queue that
# holds pre-generated chunks, and start the background producer threads.
#   max_nb_cached_chunks: passed as `maxsize` to queue.Queue.
#   chunk_size: stored on self; used as the `nb` argument when a chunk is
#               generated for the queue.
#   nb_threads: number of daemon threads started with target self.fill_cache.
# NOTE(review): part of this method is not visible in this excerpt, so the
# statements below may sit under a condition that is not shown here.
12 def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1):
14 self.chunk_size = chunk_size
# NOTE(review): with the default max_nb_cached_chunks=None the queue is
# built with maxsize=None; the stdlib Queue compares maxsize with 0 when
# putting an item -- confirm that callers always pass an int here.
15 self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
# range(nb_threads) is empty for the default nb_threads=-1, so this loop
# starts no thread unless a positive count is given.
16 for _ in range(nb_threads):
# daemon=True: these threads do not keep the interpreter alive at exit.
17 threading.Thread(target=self.fill_cache, daemon=True).start()
22 def nb_token_values(self):
25 def trivial_prompts_and_answers(self, prompts, answers):
28 # The method to implement; it returns two tensors of size nb x D and nb x D'
29 def generate_prompts_and_answers_(self, nb):
32 # Save a file to visualize quizzes; you can save a txt or png file
39 predicted_prompts=None,
40 predicted_answers=None,
46 prompts, answers = self.generate_prompts_and_answers_(self.chunk_size)
48 self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
# Return a (prompts, answers) pair for a request of `nb` samples.
# Without a queue the request is forwarded directly to
# generate_prompts_and_answers_. Otherwise chunks are taken from self.queue,
# concatenated along dim 0, and trailing rows are trimmed off before
# returning.
# NOTE(review): the loop header and the definition of `k` are not visible in
# this excerpt; the comments below only describe the lines that are shown.
50 def generate_prompts_and_answers(self, nb):
# No queue: generate synchronously in the calling thread.
51 if self.queue is None:
52 return self.generate_prompts_and_answers_(nb)
# NOTE(review): the test reads `self.rest`, but the next line unpacks a bare
# `rest`. Because `rest` is assigned further down in this same function, it
# is a local name, so reaching this line would raise UnboundLocalError.
# Presumably `self.rest` was intended -- confirm.
54 if self.rest is not None:
55 prompts, answers = rest
57 prompts, answers = [], []
# Number of samples already held, summed over the first dimension of each
# element of `prompts`.
61 n = sum([p.size(0) for p in prompts])
66 desc="world generation",
# Blocks until a chunk is available in the queue.
69 p, s = self.queue.get(block=True)
# Advance the progress bar by the number of samples in this chunk.
73 pbar.update(p.size(0))
# Merge everything into one tensor per side along the sample dimension.
75 prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
# Sanity check: the running count must match the concatenated size.
76 assert n == prompts.size(0)
# Split off the last k rows and return only the leading part.
# NOTE(review): the split-off rows are bound to the local `rest`, not to
# `self.rest`; unless a line not shown here stores them, they are discarded
# when the function returns -- confirm.
81 rest = (prompts[-k:], answers[-k:])
82 prompts, answers = prompts[:-k], answers[:-k]
84 return prompts, answers