X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=problem.py;h=7eeb6b40113ed93a70c3c8267e7fa41bd3a1acdf;hb=dc28a8c5cc4456cd9512c57947b104ac23338c8f;hp=a49634d1fa9872167b46bed32c7d8e6803c9ac0e;hpb=9ec709a2a08eb82dfc17ef1e24aa9a84751d63e0;p=culture.git diff --git a/problem.py b/problem.py index a49634d..7eeb6b4 100755 --- a/problem.py +++ b/problem.py @@ -9,14 +9,30 @@ import threading, queue, torch, tqdm class Problem: + def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1): + if nb_threads > 0: + self.chunk_size = chunk_size + self.queue = queue.Queue(maxsize=max_nb_cached_chunks) + for _ in range(nb_threads): + threading.Thread(target=self.fill_cache, daemon=True).start() + self.rest = None + else: + self.queue = None + + def nb_cached_quizzes(self): + if self.queue is None: + return None + else: + return self.queue.qsize() * self.chunk_size + def nb_token_values(self): pass def trivial_prompts_and_answers(self, prompts, answers): pass - # returns two tensors nb x D and nb x D' - def generate_prompts_and_answers(self, nb): + # The one to implement, returns two tensors nb x D and nb x D' + def generate_prompts_and_answers_(self, nb): pass # save a file to vizualize quizzes, you can save a txt or png file @@ -31,49 +47,16 @@ class Problem: ): pass - -class MultiThreadProblem: - def __init__(self, problem, max_nb_cached_chunks, chunk_size, nb_threads=1): - self.problem = problem - self.chunk_size = chunk_size - self.queue = queue.Queue(maxsize=max_nb_cached_chunks) - for _ in range(nb_threads): - threading.Thread(target=self.fill_cache, daemon=True).start() - self.rest = None - - def nb_token_values(self): - return self.problem.nb_token_values() - - def save_quizzes( - self, - result_dir, - filename_prefix, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, - ): - self.problem.save_quizzes( - result_dir, - filename_prefix, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, - ) - def fill_cache(self): while True: - prompts, answers = self.problem.generate_prompts_and_answers( - self.chunk_size - ) + prompts, answers = self.generate_prompts_and_answers_(self.chunk_size) self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True) - def trivial_prompts_and_answers(self, prompts, answers): - return self.problem.trivial_prompts_and_answers(prompts, answers) - def generate_prompts_and_answers(self, nb): + if self.queue is None: + return self.generate_prompts_and_answers_(nb) + if self.rest is not None: prompts, answers = rest else: @@ -105,3 +88,6 @@ class MultiThreadProblem: prompts, answers = prompts[:-k], answers[:-k] return prompts, answers + + def save_some_examples(self, result_dir): + pass