3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import threading, queue, torch
12 def nb_token_values(self):
15 def trivial_prompts_and_answers(self, prompts, answers):
    # Returns two tensors: the prompts, of size nb x D, and the answers, of size nb x D'.
19 def generate_prompts_and_answers(self, nb):
    # Save a file to visualize quizzes; it can be a txt or a png file.
29 predicted_prompts=None,
30 predicted_answers=None,
class MultiThreadProblem:
    """Wrap a problem and pre-generate its quizzes in a background thread.

    A daemon thread repeatedly asks the wrapped problem for chunks of
    ``chunk_size`` samples and pushes them into a bounded queue holding at
    most ``max_nb_cached_chunks`` chunks. ``generate_prompts_and_answers``
    then assembles exactly ``nb`` samples from those chunks and keeps any
    surplus for the next call, so no generated sample is lost.

    The other methods simply delegate to the wrapped problem.
    """

    def __init__(self, problem, max_nb_cached_chunks, chunk_size):
        """
        Args:
            problem: object providing ``nb_token_values``,
                ``generate_prompts_and_answers``,
                ``trivial_prompts_and_answers`` and ``save_quizzes``.
            max_nb_cached_chunks: maximum number of chunks buffered in the
                queue; the worker blocks once the queue is full.
            chunk_size: number of samples requested from ``problem`` per
                background call.
        """
        self.problem = problem
        self.chunk_size = chunk_size
        self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
        # Surplus (prompts, answers) left over from the previous call, stored
        # as lists of tensors. Set before the worker starts so the object is
        # fully initialized by then.
        self.rest = None
        threading.Thread(target=self.fill_cache, daemon=True).start()

    def nb_token_values(self):
        """Delegate to the wrapped problem."""
        return self.problem.nb_token_values()

    def save_quizzes(
        self,
        result_dir,
        filename_prefix,
        prompts,
        answers,
        predicted_prompts=None,
        predicted_answers=None,
    ):
        """Delegate to the wrapped problem, including any predictions."""
        # Forward the caller's predictions; they used to be replaced by None.
        self.problem.save_quizzes(
            result_dir,
            filename_prefix,
            prompts,
            answers,
            predicted_prompts=predicted_prompts,
            predicted_answers=predicted_answers,
        )

    def fill_cache(self):
        """Worker loop: generate chunks forever and enqueue them.

        ``queue.put`` blocks while the queue is full, which throttles the
        worker to at most ``max_nb_cached_chunks`` chunks ahead.
        """
        while True:
            prompts, answers = self.problem.generate_prompts_and_answers(
                self.chunk_size
            )
            self.queue.put((prompts, answers), block=True)

    def trivial_prompts_and_answers(self, prompts, answers):
        """Delegate to the wrapped problem."""
        return self.problem.trivial_prompts_and_answers(prompts, answers)

    def generate_prompts_and_answers(self, nb):
        """Return exactly ``nb`` (prompts, answers) samples.

        Samples come first from the surplus of the previous call, then from
        queued chunks. Chunks are concatenated along dim 0, and anything
        beyond ``nb`` is kept in ``self.rest`` for the next call. ``nb`` is
        expected to be positive: with no cached surplus and ``nb <= 0`` there
        is nothing to concatenate.

        Returns:
            A ``(prompts, answers)`` pair of tensors with ``nb`` rows each.
        """
        if self.rest is not None:
            prompts, answers = self.rest
        else:
            prompts, answers = [], []
        self.rest = None

        n = sum(p.size(0) for p in prompts)
        while n < nb:
            p, s = self.queue.get(block=True)
            prompts.append(p)
            answers.append(s)
            n += p.size(0)

        prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)

        k = n - nb
        if k > 0:
            # Store the surplus as one-element lists so the next call can
            # append further chunks to them before concatenating.
            self.rest = ([prompts[-k:]], [answers[-k:]])
            prompts, answers = prompts[:-k], answers[:-k]

        return prompts, answers