# problem.py -- reconstructed from the post-image of the gitweb blobdiff
# (culture.git, blob 05f3b20, commit 5f5c6c0), with review fixes applied.

# Written by Francois Fleuret

import queue
import threading

import torch
import tqdm


class Problem:
    """Base class for quiz generators, with an optional threaded chunk cache.

    Subclasses implement ``generate_prompts_and_answers_``. When the instance
    is built with ``nb_threads > 0``, daemon threads repeatedly call that
    method with ``chunk_size`` and push the results (moved to CPU) into a
    bounded queue; ``generate_prompts_and_answers`` then serves requests from
    that queue. Otherwise requests are forwarded directly to the subclass.
    """

    def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1):
        """Set up the cache and, if requested, start the producer threads.

        Args:
            max_nb_cached_chunks: capacity of the chunk queue (threaded mode).
            chunk_size: number of quizzes requested per generated chunk.
            nb_threads: number of producer threads; values <= 0 disable the
                cache entirely.

        Raises:
            ValueError: in threaded mode, if ``max_nb_cached_chunks`` or
                ``chunk_size`` is None.

        NOTE(review): the producer threads start running inside this
        constructor, so a subclass should finish initialising whatever
        ``generate_prompts_and_answers_`` needs before calling it.
        """
        self.chunk_size = chunk_size
        # Quizzes fetched from the queue but not yet handed out. Initialised
        # before any thread starts and in both modes.
        self.rest = None

        if nb_threads > 0:
            # queue.Queue(maxsize=None) only fails later, inside put() in the
            # worker thread, which leaves the consumer blocked forever; fail
            # early and explicitly instead.
            if max_nb_cached_chunks is None or chunk_size is None:
                raise ValueError(
                    "max_nb_cached_chunks and chunk_size are required "
                    "when nb_threads > 0"
                )
            self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
            for _ in range(nb_threads):
                threading.Thread(target=self.fill_cache, daemon=True).start()
        else:
            self.queue = None

    def nb_cached_quizzes(self):
        """Return the approximate number of queued quizzes, or None if uncached.

        The value is ``qsize() * chunk_size``; it does not include leftover
        quizzes held in ``self.rest``.
        """
        if self.queue is None:
            return None
        return self.queue.qsize() * self.chunk_size

    def nb_token_values(self):
        """To be implemented by subclasses (presumably the vocabulary size)."""
        pass

    def trivial_prompts_and_answers(self, prompts, answers):
        """To be implemented by subclasses.

        NOTE(review): semantics are not visible here; presumably flags quizzes
        that are trivially solvable -- confirm against callers.
        """
        pass

    # The one to implement, returns two tensors nb x D and nb x D'
    def generate_prompts_and_answers_(self, nb):
        """Generate ``nb`` quizzes; must return a (prompts, answers) pair."""
        pass

    # save a file to visualize quizzes, you can save a txt or png file
    def save_quiz_illustrations(
        self,
        result_dir,
        filename_prefix,
        prompts,
        answers,
        predicted_prompts=None,
        predicted_answers=None,
    ):
        """To be implemented by subclasses: write a visualisation of quizzes."""
        pass

    def fill_cache(self):
        """Producer loop run by each daemon thread; never returns.

        Generates one chunk at a time and blocks on ``put`` while the queue
        is full.
        """
        while True:
            prompts, answers = self.generate_prompts_and_answers_(self.chunk_size)

            self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)

    def generate_prompts_and_answers(self, nb):
        """Return ``nb`` quizzes as a (prompts, answers) pair of tensors.

        Without a cache this forwards to ``generate_prompts_and_answers_``.
        With a cache, chunks are pulled from the queue until at least ``nb``
        quizzes are available; any surplus rows are kept in ``self.rest`` and
        served first on the next call.
        """
        if self.queue is None:
            return self.generate_prompts_and_answers_(nb)

        # Start from the leftover of the previous call. It is stored as a
        # pair of tensors, so wrap each one in a list to be able to append.
        if self.rest is not None:
            prompts, answers = [self.rest[0]], [self.rest[1]]
        else:
            prompts, answers = [], []

        self.rest = None

        n = sum(p.size(0) for p in prompts)

        with tqdm.tqdm(
            total=nb,
            dynamic_ncols=True,
            desc="world generation",
        ) as pbar:
            # "or not prompts" guarantees at least one chunk so that
            # torch.cat below never receives an empty list (e.g. nb == 0).
            while n < nb or not prompts:
                p, s = self.queue.get(block=True)
                prompts.append(p)
                answers.append(s)
                n += p.size(0)
                pbar.update(p.size(0))

        prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
        assert n == prompts.size(0)

        # Number of surplus quizzes fetched beyond the request.
        k = n - nb

        if k > 0:
            self.rest = (prompts[-k:], answers[-k:])
            prompts, answers = prompts[:-k], answers[:-k]

        return prompts, answers

    def save_some_examples(self, result_dir):
        """To be implemented by subclasses (presumably saves sample quizzes)."""
        pass