Merge branch 'dev'
[culture.git] / problem.py
index 7dd60dc..9bee5b2 100755 (executable)
@@ -5,45 +5,83 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import threading, queue, torch
+import threading, queue, torch, tqdm
 
 
 class Problem:
-    def nb_token_values(self):
-        pass
+    def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1):
+        if nb_threads > 0:
+            self.chunk_size = chunk_size
+            self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
+            for _ in range(nb_threads):
+                threading.Thread(target=self.fill_cache, daemon=True).start()
+            self.rest = None
+        else:
+            self.queue = None
 
-    def trivial_prompts_and_answers(self, prompts, answers):
-        pass
+    def nb_cached_quizzes(self):
+        if self.queue is None:
+            return None
+        else:
+            return self.queue.qsize() * self.chunk_size
 
-    # returns two tensors nb x D and nb x D'
-    def generate_prompts_and_answers(self, nb):
-        pass
+    def fill_cache(self):
+        while True:
+            quizzes = self.generate_w_quizzes_(self.chunk_size)
+            self.queue.put(quizzes.to("cpu"), block=True)
 
-    # save a file to vizualize quizzes, you can save a txt or png file
-    def save_quizzes(
-        self,
-        result_dir,
-        filename_prefix,
-        prompts,
-        answers,
-        predicted_prompts=None,
-        predicted_answers=None,
-    ):
-        pass
+    def generate_w_quizzes(self, nb, progress_bar=True):
+        if self.queue is None:
+            return self.generate_w_quizzes_(nb)
 
+        if self.rest is not None:
+            quizzes = rest
+        else:
+            quizzes = []
 
-class MultiThreadProblem:
-    def __init__(self, problem, max_nb_cached_chunks, chunk_size):
-        self.problem = problem
-        self.chunk_size = chunk_size
-        self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
-        threading.Thread(target=self.fill_cache, daemon=True).start()
         self.rest = None
 
-    def nb_token_values(self):
-        return self.problem.nb_token_values()
+        n = sum([q.size(0) for q in quizzes])
+
+        if progress_bar:
+            with tqdm.tqdm(
+                total=nb,
+                dynamic_ncols=True,
+                desc="world generation",
+            ) as pbar:
+                while n < nb:
+                    q = self.queue.get(block=True)
+                    quizzes.append(q)
+                    n += q.size(0)
+                    pbar.update(q.size(0))
+        else:
+            while n < nb:
+                q = self.queue.get(block=True)
+                quizzes.append(q)
+                n += q.size(0)
+
+        quizzes = torch.cat(quizzes, dim=0)
+        assert n == quizzes.size(0)
+
+        k = n - nb
+
+        if k > 0:
+            rest = quizzes[-k:]
+            quizzes = quizzes[:-k]
+
+        return quizzes
 
-    def save_quizzes(
+    ######################################################################
+
+    def trivial_prompts_and_answers(self, prompts, answers):
+        pass
+
+    # The one to implement, returns two tensors nb x D and nb x D'
+    def generate_w_quizzes_(self, nb):
+        pass
+
+    # save a file to vizualize quizzes, you can save a txt or png file
+    def save_quiz_illustrations(
         self,
         result_dir,
         filename_prefix,
@@ -52,48 +90,9 @@ class MultiThreadProblem:
         predicted_prompts=None,
         predicted_answers=None,
     ):
-        self.problem.save_quizzes(
-            result_dir,
-            filename_prefix,
-            prompts,
-            answers,
-            predicted_prompts=None,
-            predicted_answers=None,
-        )
-
-    def fill_cache(self):
-        while True:
-            prompts, answers = self.problem.generate_prompts_and_answers(
-                self.chunk_size
-            )
-
-            self.queue.put((prompts, answers), block=True)
-
-    def trivial_prompts_and_answers(self, prompts, answers):
-        return self.problem.trivial_prompts_and_answers(prompts, answers)
-
-    def generate_prompts_and_answers(self, nb):
-        if self.rest is not None:
-            prompts, answers = rest
-        else:
-            prompts, answers = [], []
-
-        self.rest = None
-
-        n = sum([p.size(0) for p in prompts])
-
-        while n < nb:
-            p, s = self.queue.get(block=True)
-            prompts.append(p)
-            answers.append(s)
-            n += p.size(0)
-
-        prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
-
-        k = n - nb
+        pass
 
-        if k > 0:
-            rest = (prompts[-k:], answers[-k:])
-            prompts, answers = prompts[:-k], answers[:-k]
+    def save_some_examples(self, result_dir):
+        pass
 
-        return prompts, answers
+    ######################################################################