X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=problem.py;h=05f3b20984165acab83fd3ae563e6be896a14d6a;hb=refs%2Fheads%2Fmaster;hp=617b2a86d01e27233bf5ded070fbb85fb74e7d43;hpb=1e7e3d3877ae98737852b78f35a98030dcc0701f;p=culture.git diff --git a/problem.py b/problem.py index 617b2a8..9bee5b2 100755 --- a/problem.py +++ b/problem.py @@ -25,18 +25,63 @@ class Problem: else: return self.queue.qsize() * self.chunk_size - def nb_token_values(self): - pass + def fill_cache(self): + while True: + quizzes = self.generate_w_quizzes_(self.chunk_size) + self.queue.put(quizzes.to("cpu"), block=True) + + def generate_w_quizzes(self, nb, progress_bar=True): + if self.queue is None: + return self.generate_w_quizzes_(nb) + + if self.rest is not None: + quizzes = rest + else: + quizzes = [] + + self.rest = None + + n = sum([q.size(0) for q in quizzes]) + + if progress_bar: + with tqdm.tqdm( + total=nb, + dynamic_ncols=True, + desc="world generation", + ) as pbar: + while n < nb: + q = self.queue.get(block=True) + quizzes.append(q) + n += q.size(0) + pbar.update(q.size(0)) + else: + while n < nb: + q = self.queue.get(block=True) + quizzes.append(q) + n += q.size(0) + + quizzes = torch.cat(quizzes, dim=0) + assert n == quizzes.size(0) + + k = n - nb + + if k > 0: + rest = quizzes[-k:] + quizzes = quizzes[:-k] + + return quizzes + + ###################################################################### def trivial_prompts_and_answers(self, prompts, answers): pass # The one to implement, returns two tensors nb x D and nb x D' - def generate_prompts_and_answers_(self, nb): + def generate_w_quizzes_(self, nb): pass # save a file to vizualize quizzes, you can save a txt or png file - def save_quizzes( + def save_quiz_illustrations( self, result_dir, filename_prefix, @@ -47,44 +92,7 @@ class Problem: ): pass - def fill_cache(self): - while True: - prompts, answers = self.generate_prompts_and_answers_(self.chunk_size) - - self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True) - - def generate_prompts_and_answers(self, nb): - if self.queue is None: - return self.generate_prompts_and_answers_(nb) - - if self.rest is not None: - prompts, answers = rest - else: - prompts, answers = [], [] - - self.rest = None - - n = sum([p.size(0) for p in prompts]) - - with tqdm.tqdm( - total=nb, - dynamic_ncols=True, - desc="world generation", - ) as pbar: - while n < nb: - p, s = self.queue.get(block=True) - prompts.append(p) - answers.append(s) - n += p.size(0) - pbar.update(p.size(0)) - - prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0) - assert n == prompts.size(0) - - k = n - nb - - if k > 0: - rest = (prompts[-k:], answers[-k:]) - prompts, answers = prompts[:-k], answers[:-k] + def save_some_examples(self, result_dir): + pass - return prompts, answers + ######################################################################