# Written by Francois Fleuret <francois@fleuret.org>
-import threading, queue, torch
+import threading, queue, torch, tqdm
class Problem:
+ def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1):
+     """Optionally start background threads that pre-generate quizzes.
+
+     With nb_threads > 0, that many daemon threads run fill_cache(),
+     each pushing chunks of chunk_size quizzes into a queue created
+     with maxsize=max_nb_cached_chunks. Otherwise no thread is started
+     and self.queue is None.
+
+     The threads start right away and call
+     generate_prompts_and_answers_(), so whatever state that method
+     needs must already be set when this constructor runs.
+
+     Raises:
+         ValueError: in threaded mode, when max_nb_cached_chunks or
+             chunk_size is None.
+     """
+     # Defined in both modes so the attributes always exist.
+     self.chunk_size = chunk_size
+     self.rest = None
+
+     if nb_threads <= 0:
+         self.queue = None
+         return
+
+     # queue.Queue(maxsize=None) is constructed without complaint but
+     # raises TypeError on the first put(); the workers would all die
+     # and any consumer blocked on get() would wait forever. Fail here.
+     if max_nb_cached_chunks is None:
+         raise ValueError("max_nb_cached_chunks is required when nb_threads > 0")
+     if chunk_size is None:
+         raise ValueError("chunk_size is required when nb_threads > 0")
+
+     self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
+     for _ in range(nb_threads):
+         threading.Thread(target=self.fill_cache, daemon=True).start()
+
+ def nb_cached_quizzes(self):
+     """Return the queue's chunk count times chunk_size, or None when there is no queue."""
+     cache = self.queue
+     return None if cache is None else cache.qsize() * self.chunk_size
+
def nb_token_values(self):
pass
def trivial_prompts_and_answers(self, prompts, answers):
pass
- # returns two tensors nb x D and nb x D'
- def generate_prompts_and_answers(self, nb):
+ # The one to implement (this stub only passes): returns two tensors, prompts (nb x D) and answers (nb x D'); fill_cache() and generate_prompts_and_answers() call it.
+ def generate_prompts_and_answers_(self, nb):
pass
# save a file to vizualize quizzes, you can save a txt or png file
):
pass
-
-class MultiThreadProblem:
- def __init__(self, problem, max_nb_cached_chunks, chunk_size):
- self.problem = problem
- self.chunk_size = chunk_size
- self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
- threading.Thread(target=self.fill_cache, daemon=True).start()
- self.rest = None
-
- def nb_token_values(self):
- return self.problem.nb_token_values()
-
- def save_quizzes(
- self,
- result_dir,
- filename_prefix,
- prompts,
- answers,
- predicted_prompts=None,
- predicted_answers=None,
- ):
- self.problem.save_quizzes(
- result_dir,
- filename_prefix,
- prompts,
- answers,
- predicted_prompts=None,
- predicted_answers=None,
- )
-
def fill_cache(self):
while True:
- prompts, answers = self.problem.generate_prompts_and_answers(
- self.chunk_size
- )
+ prompts, answers = self.generate_prompts_and_answers_(self.chunk_size)  # one chunk of chunk_size quizzes per iteration
- self.queue.put((prompts, answers), block=True)
-
- def trivial_prompts_and_answers(self, prompts, answers):
- return self.problem.trivial_prompts_and_answers(prompts, answers)
+ self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)  # chunks are moved to the CPU before queuing; put() blocks while the queue is full
def generate_prompts_and_answers(self, nb):
+ if self.queue is None:
+ return self.generate_prompts_and_answers_(nb)
+
if self.rest is not None:
prompts, answers = rest
else:
n = sum([p.size(0) for p in prompts])
- while n < nb:
- p, s = self.queue.get(block=True)
- prompts.append(p)
- answers.append(s)
- n += p.size(0)
+ with tqdm.tqdm(
+ total=nb,
+ dynamic_ncols=True,
+ desc="world generation",
+ ) as pbar:
+ while n < nb:
+ p, s = self.queue.get(block=True)
+ prompts.append(p)
+ answers.append(s)
+ n += p.size(0)
+ pbar.update(p.size(0))
prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
+ assert n == prompts.size(0)
k = n - nb