# Written by Francois Fleuret

import queue
import threading

import torch
import tqdm


class Problem:
    """Base class for quiz-generating problems, with optional background caching.

    Subclasses implement ``generate_w_quizzes_``. When ``nb_threads > 0``,
    daemon threads keep calling it and push chunks of ``chunk_size`` quizzes
    (moved to CPU) into a bounded queue, from which ``generate_w_quizzes``
    then draws.
    """

    def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1):
        """Set up the optional chunk cache and start its worker threads.

        Args:
            max_nb_cached_chunks: maximum number of chunks held in the queue.
                Required when ``nb_threads > 0``.
            chunk_size: number of quizzes requested per generated chunk.
                Required when ``nb_threads > 0``.
            nb_threads: number of background generation threads; caching is
                disabled when it is not strictly positive.

        Raises:
            ValueError: if caching is requested without both sizes.
        """
        # Quizzes fetched from the queue beyond what the last call asked for
        self.rest = None

        if nb_threads > 0:
            # queue.Queue(maxsize=None) makes put() raise TypeError inside
            # the worker thread, which dies silently and leaves the consumer
            # blocked forever on get(). Fail early and explicitly instead.
            if max_nb_cached_chunks is None or chunk_size is None:
                raise ValueError(
                    "max_nb_cached_chunks and chunk_size are required "
                    "when nb_threads > 0"
                )
            self.chunk_size = chunk_size
            self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
            for _ in range(nb_threads):
                threading.Thread(target=self.fill_cache, daemon=True).start()
        else:
            self.queue = None

    def nb_cached_quizzes(self):
        """Return the approximate number of queued quizzes, or None if no cache."""
        if self.queue is None:
            return None
        return self.queue.qsize() * self.chunk_size

    def fill_cache(self):
        """Worker loop: generate chunks forever and push them into the queue."""
        while True:
            quizzes = self.generate_w_quizzes_(self.chunk_size)
            # Blocks while the queue is full, which throttles generation
            self.queue.put(quizzes.to("cpu"), block=True)

    def generate_w_quizzes(self, nb, progress_bar=True):
        """Return ``nb`` quizzes, drawn from the cache when it is enabled.

        Without a cache this simply delegates to ``generate_w_quizzes_``.
        With a cache, whole chunks are taken from the queue until at least
        ``nb`` quizzes are available; the surplus is kept in ``self.rest``
        and served first on the next call.

        Args:
            nb: number of quizzes to return.
            progress_bar: whether to display a tqdm progress bar while
                waiting for chunks.
        """
        if self.queue is None:
            return self.generate_w_quizzes_(nb)

        # Start from the leftover of the previous call, as a list of chunks
        quizzes = [] if self.rest is None else [self.rest]
        self.rest = None

        n = sum(q.size(0) for q in quizzes)

        with tqdm.tqdm(
            total=nb,
            dynamic_ncols=True,
            desc="world generation",
            disable=not progress_bar,
        ) as pbar:
            while n < nb:
                q = self.queue.get(block=True)
                quizzes.append(q)
                n += q.size(0)
                pbar.update(q.size(0))

        quizzes = torch.cat(quizzes, dim=0)
        assert n == quizzes.size(0)

        k = n - nb

        if k > 0:
            # Keep the surplus for the next call instead of dropping it
            self.rest = quizzes[-k:]
            quizzes = quizzes[:-k]

        return quizzes

    ######################################################################

    def trivial_prompts_and_answers(self, prompts, answers):
        pass

    # The one to implement, returns two tensors nb x D and nb x D'
    # NOTE(review): fill_cache and generate_w_quizzes call .to() and
    # .size(0) on the result, i.e. the cached path treats it as a single
    # tensor -- confirm which contract is intended.
    def generate_w_quizzes_(self, nb):
        pass

    # save a file to visualize quizzes, you can save a txt or png file
    def save_quiz_illustrations(
        self,
        result_dir,
        filename_prefix,
        prompts,
        answers,
        predicted_prompts=None,
        predicted_answers=None,
    ):
        pass

    def save_some_examples(self, result_dir):
        pass

    ######################################################################