Update.
[culture.git] / problem.py
index 0bc83a1..7dd60dc 100755 (executable)
@@ -5,11 +5,16 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
+import threading, queue, torch
+
 
 class Problem:
     def nb_token_values(self):
         pass
 
+    def trivial_prompts_and_answers(self, prompts, answers):
+        pass
+
     # returns two tensors nb x D and nb x D'
     def generate_prompts_and_answers(self, nb):
         pass
@@ -21,7 +26,74 @@ class Problem:
         filename_prefix,
         prompts,
         answers,
-        predicted_prompt=None,
+        predicted_prompts=None,
         predicted_answers=None,
     ):
         pass
+
+
+class MultiThreadProblem:
+    def __init__(self, problem, max_nb_cached_chunks, chunk_size):
+        self.problem = problem
+        self.chunk_size = chunk_size
+        self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
+        threading.Thread(target=self.fill_cache, daemon=True).start()
+        self.rest = None
+
+    def nb_token_values(self):
+        return self.problem.nb_token_values()
+
+    def save_quizzes(
+        self,
+        result_dir,
+        filename_prefix,
+        prompts,
+        answers,
+        predicted_prompts=None,
+        predicted_answers=None,
+    ):
+        self.problem.save_quizzes(
+            result_dir,
+            filename_prefix,
+            prompts,
+            answers,
+            predicted_prompts=None,
+            predicted_answers=None,
+        )
+
+    def fill_cache(self):
+        while True:
+            prompts, answers = self.problem.generate_prompts_and_answers(
+                self.chunk_size
+            )
+
+            self.queue.put((prompts, answers), block=True)
+
+    def trivial_prompts_and_answers(self, prompts, answers):
+        return self.problem.trivial_prompts_and_answers(prompts, answers)
+
+    def generate_prompts_and_answers(self, nb):
+        if self.rest is not None:
+            prompts, answers = rest
+        else:
+            prompts, answers = [], []
+
+        self.rest = None
+
+        n = sum([p.size(0) for p in prompts])
+
+        while n < nb:
+            p, s = self.queue.get(block=True)
+            prompts.append(p)
+            answers.append(s)
+            n += p.size(0)
+
+        prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
+
+        k = n - nb
+
+        if k > 0:
+            rest = (prompts[-k:], answers[-k:])
+            prompts, answers = prompts[:-k], answers[:-k]
+
+        return prompts, answers