Merge branch 'dev'
[culture.git] / problem.py
index 95a9c41..9bee5b2 100755 (executable)
@@ -5,17 +5,94 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
+import threading, queue, torch, tqdm
+
 
 class Problem:
-    # returns a nb x (L+1+L) long tensor where L is the length of one
-    # of the two states of a quizz
-    def generate_seq(self, nb):
+    def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1):
+        if nb_threads > 0:
+            self.chunk_size = chunk_size
+            self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
+            for _ in range(nb_threads):
+                threading.Thread(target=self.fill_cache, daemon=True).start()
+            self.rest = None
+        else:
+            self.queue = None
+
+    def nb_cached_quizzes(self):
+        if self.queue is None:
+            return None
+        else:
+            return self.queue.qsize() * self.chunk_size
+
+    def fill_cache(self):
+        while True:
+            quizzes = self.generate_w_quizzes_(self.chunk_size)
+            self.queue.put(quizzes.to("cpu"), block=True)
+
+    def generate_w_quizzes(self, nb, progress_bar=True):
+        if self.queue is None:
+            return self.generate_w_quizzes_(nb)
+
+        if self.rest is not None:
+            quizzes = rest
+        else:
+            quizzes = []
+
+        self.rest = None
+
+        n = sum([q.size(0) for q in quizzes])
+
+        if progress_bar:
+            with tqdm.tqdm(
+                total=nb,
+                dynamic_ncols=True,
+                desc="world generation",
+            ) as pbar:
+                while n < nb:
+                    q = self.queue.get(block=True)
+                    quizzes.append(q)
+                    n += q.size(0)
+                    pbar.update(q.size(0))
+        else:
+            while n < nb:
+                q = self.queue.get(block=True)
+                quizzes.append(q)
+                n += q.size(0)
+
+        quizzes = torch.cat(quizzes, dim=0)
+        assert n == quizzes.size(0)
+
+        k = n - nb
+
+        if k > 0:
+            rest = quizzes[-k:]
+            quizzes = quizzes[:-k]
+
+        return quizzes
+
+    ######################################################################
+
+    def trivial_prompts_and_answers(self, prompts, answers):
+        pass
+
+    # The one to implement, returns two tensors nb x D and nb x D'
+    def generate_w_quizzes_(self, nb):
         pass
 
     # save a file to vizualize quizzes, you can save a txt or png file
-    def save_quizzes(self, input, result_dir, filename_prefix):
+    def save_quiz_illustrations(
+        self,
+        result_dir,
+        filename_prefix,
+        prompts,
+        answers,
+        predicted_prompts=None,
+        predicted_answers=None,
+    ):
         pass
 
-    # returns a pair (forward_tokens, backward_token)
-    def direction_tokens(self):
+    def save_some_examples(self, result_dir):
         pass
+
+    ######################################################################