X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=problem.py;h=617b2a86d01e27233bf5ded070fbb85fb74e7d43;hb=1fe0529f4ac7cacfa8fced2885b7da1e639962a9;hp=95a9c4180be0bbac89bd19b8329b80c22942a72f;hpb=4f0057b363762698f90eea05de154e62b6883bd0;p=culture.git diff --git a/problem.py b/problem.py index 95a9c41..617b2a8 100755 --- a/problem.py +++ b/problem.py @@ -5,17 +5,86 @@ # Written by Francois Fleuret +import threading, queue, torch, tqdm + class Problem: - # returns a nb x (L+1+L) long tensor where L is the length of one - # of the two states of a quizz - def generate_seq(self, nb): + def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1): + if nb_threads > 0: + self.chunk_size = chunk_size + self.queue = queue.Queue(maxsize=max_nb_cached_chunks) + for _ in range(nb_threads): + threading.Thread(target=self.fill_cache, daemon=True).start() + self.rest = None + else: + self.queue = None + + def nb_cached_quizzes(self): + if self.queue is None: + return None + else: + return self.queue.qsize() * self.chunk_size + + def nb_token_values(self): pass - # save a file to vizualize quizzes, you can save a txt or png file - def save_quizzes(self, input, result_dir, filename_prefix): + def trivial_prompts_and_answers(self, prompts, answers): pass - # returns a pair (forward_tokens, backward_token) - def direction_tokens(self): + # The one to implement, returns two tensors nb x D and nb x D' + def generate_prompts_and_answers_(self, nb): pass + + # save a file to vizualize quizzes, you can save a txt or png file + def save_quizzes( + self, + result_dir, + filename_prefix, + prompts, + answers, + predicted_prompts=None, + predicted_answers=None, + ): + pass + + def fill_cache(self): + while True: + prompts, answers = self.generate_prompts_and_answers_(self.chunk_size) + + self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True) + + def generate_prompts_and_answers(self, nb): + if self.queue is None: + return self.generate_prompts_and_answers_(nb) + + if self.rest is not None: + prompts, answers = rest + else: + prompts, answers = [], [] + + self.rest = None + + n = sum([p.size(0) for p in prompts]) + + with tqdm.tqdm( + total=nb, + dynamic_ncols=True, + desc="world generation", + ) as pbar: + while n < nb: + p, s = self.queue.get(block=True) + prompts.append(p) + answers.append(s) + n += p.size(0) + pbar.update(p.size(0)) + + prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0) + assert n == prompts.size(0) + + k = n - nb + + if k > 0: + rest = (prompts[-k:], answers[-k:]) + prompts, answers = prompts[:-k], answers[:-k] + + return prompts, answers