Update.
[culture.git] / problem.py
index 0bc83a1..7eeb6b4 100755 (executable)
@@ -5,13 +5,34 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
+import threading, queue, torch, tqdm
+
 
 class Problem:
+    def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1):
+        if nb_threads > 0:
+            self.chunk_size = chunk_size
+            self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
+            for _ in range(nb_threads):
+                threading.Thread(target=self.fill_cache, daemon=True).start()
+            self.rest = None
+        else:
+            self.queue = None
+
+    def nb_cached_quizzes(self):
+        if self.queue is None:
+            return None
+        else:
+            return self.queue.qsize() * self.chunk_size
+
     def nb_token_values(self):
         pass
 
-    # returns two tensors nb x D and nb x D'
-    def generate_prompts_and_answers(self, nb):
+    def trivial_prompts_and_answers(self, prompts, answers):
+        pass
+
+    # The one to implement, returns two tensors nb x D and nb x D'
+    def generate_prompts_and_answers_(self, nb):
         pass
 
     # save a file to vizualize quizzes, you can save a txt or png file
@@ -21,7 +42,52 @@ class Problem:
         filename_prefix,
         prompts,
         answers,
-        predicted_prompt=None,
+        predicted_prompts=None,
         predicted_answers=None,
     ):
         pass
+
+    def fill_cache(self):
+        while True:
+            prompts, answers = self.generate_prompts_and_answers_(self.chunk_size)
+
+            self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
+
+    def generate_prompts_and_answers(self, nb):
+        if self.queue is None:
+            return self.generate_prompts_and_answers_(nb)
+
+        if self.rest is not None:
+            prompts, answers = rest
+        else:
+            prompts, answers = [], []
+
+        self.rest = None
+
+        n = sum([p.size(0) for p in prompts])
+
+        with tqdm.tqdm(
+            total=nb,
+            dynamic_ncols=True,
+            desc="world generation",
+        ) as pbar:
+            while n < nb:
+                p, s = self.queue.get(block=True)
+                prompts.append(p)
+                answers.append(s)
+                n += p.size(0)
+                pbar.update(p.size(0))
+
+        prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
+        assert n == prompts.size(0)
+
+        k = n - nb
+
+        if k > 0:
+            rest = (prompts[-k:], answers[-k:])
+            prompts, answers = prompts[:-k], answers[:-k]
+
+        return prompts, answers
+
+    def save_some_examples(self, result_dir):
+        pass