X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=problem.py;h=617b2a86d01e27233bf5ded070fbb85fb74e7d43;hb=1fe0529f4ac7cacfa8fced2885b7da1e639962a9;hp=7dd60dcd4dea941df4139209f95eae9d98464afe;hpb=93cea45f62046a3481d6c05ab2cfe70f6dbc93b3;p=culture.git diff --git a/problem.py b/problem.py index 7dd60dc..617b2a8 100755 --- a/problem.py +++ b/problem.py @@ -5,18 +5,34 @@ # Written by Francois Fleuret -import threading, queue, torch +import threading, queue, torch, tqdm class Problem: + def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1): + if nb_threads > 0: + self.chunk_size = chunk_size + self.queue = queue.Queue(maxsize=max_nb_cached_chunks) + for _ in range(nb_threads): + threading.Thread(target=self.fill_cache, daemon=True).start() + self.rest = None + else: + self.queue = None + + def nb_cached_quizzes(self): + if self.queue is None: + return None + else: + return self.queue.qsize() * self.chunk_size + def nb_token_values(self): pass def trivial_prompts_and_answers(self, prompts, answers): pass - # returns two tensors nb x D and nb x D' - def generate_prompts_and_answers(self, nb): + # The one to implement, returns two tensors nb x D and nb x D' + def generate_prompts_and_answers_(self, nb): pass # save a file to vizualize quizzes, you can save a txt or png file @@ -31,48 +47,16 @@ class Problem: ): pass - -class MultiThreadProblem: - def __init__(self, problem, max_nb_cached_chunks, chunk_size): - self.problem = problem - self.chunk_size = chunk_size - self.queue = queue.Queue(maxsize=max_nb_cached_chunks) - threading.Thread(target=self.fill_cache, daemon=True).start() - self.rest = None - - def nb_token_values(self): - return self.problem.nb_token_values() - - def save_quizzes( - self, - result_dir, - filename_prefix, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, - ): - self.problem.save_quizzes( - result_dir, - filename_prefix, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, - ) - def fill_cache(self): while True: - prompts, answers = self.problem.generate_prompts_and_answers( - self.chunk_size - ) + prompts, answers = self.generate_prompts_and_answers_(self.chunk_size) - self.queue.put((prompts, answers), block=True) - - def trivial_prompts_and_answers(self, prompts, answers): - return self.problem.trivial_prompts_and_answers(prompts, answers) + self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True) def generate_prompts_and_answers(self, nb): + if self.queue is None: + return self.generate_prompts_and_answers_(nb) + if self.rest is not None: prompts, answers = rest else: @@ -82,13 +66,20 @@ class MultiThreadProblem: n = sum([p.size(0) for p in prompts]) - while n < nb: - p, s = self.queue.get(block=True) - prompts.append(p) - answers.append(s) - n += p.size(0) + with tqdm.tqdm( + total=nb, + dynamic_ncols=True, + desc="world generation", + ) as pbar: + while n < nb: + p, s = self.queue.get(block=True) + prompts.append(p) + answers.append(s) + n += p.size(0) + pbar.update(p.size(0)) prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0) + assert n == prompts.size(0) k = n - nb