Merge branch 'dev'
[culture.git] / problem.py
index a49634d..9bee5b2 100755 (executable)
@@ -9,99 +9,90 @@ import threading, queue, torch, tqdm
 
 
 class Problem:
-    def nb_token_values(self):
-        pass
-
-    def trivial_prompts_and_answers(self, prompts, answers):
-        pass
-
-    # returns two tensors nb x D and nb x D'
-    def generate_prompts_and_answers(self, nb):
-        pass
-
-    # save a file to vizualize quizzes, you can save a txt or png file
-    def save_quizzes(
-        self,
-        result_dir,
-        filename_prefix,
-        prompts,
-        answers,
-        predicted_prompts=None,
-        predicted_answers=None,
-    ):
-        pass
-
-
-class MultiThreadProblem:
-    def __init__(self, problem, max_nb_cached_chunks, chunk_size, nb_threads=1):
-        self.problem = problem
-        self.chunk_size = chunk_size
-        self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
-        for _ in range(nb_threads):
-            threading.Thread(target=self.fill_cache, daemon=True).start()
-        self.rest = None
-
-    def nb_token_values(self):
-        return self.problem.nb_token_values()
+    def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1):
+        if nb_threads > 0:
+            self.chunk_size = chunk_size
+            self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
+            for _ in range(nb_threads):
+                threading.Thread(target=self.fill_cache, daemon=True).start()
+            self.rest = None
+        else:
+            self.queue = None
 
-    def save_quizzes(
-        self,
-        result_dir,
-        filename_prefix,
-        prompts,
-        answers,
-        predicted_prompts=None,
-        predicted_answers=None,
-    ):
-        self.problem.save_quizzes(
-            result_dir,
-            filename_prefix,
-            prompts,
-            answers,
-            predicted_prompts=None,
-            predicted_answers=None,
-        )
+    def nb_cached_quizzes(self):
+        if self.queue is None:
+            return None
+        else:
+            return self.queue.qsize() * self.chunk_size
 
     def fill_cache(self):
         while True:
-            prompts, answers = self.problem.generate_prompts_and_answers(
-                self.chunk_size
-            )
+            quizzes = self.generate_w_quizzes_(self.chunk_size)
+            self.queue.put(quizzes.to("cpu"), block=True)
 
-            self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
+    def generate_w_quizzes(self, nb, progress_bar=True):
+        if self.queue is None:
+            return self.generate_w_quizzes_(nb)
 
-    def trivial_prompts_and_answers(self, prompts, answers):
-        return self.problem.trivial_prompts_and_answers(prompts, answers)
-
-    def generate_prompts_and_answers(self, nb):
         if self.rest is not None:
-            prompts, answers = rest
+            quizzes = rest
         else:
-            prompts, answers = [], []
+            quizzes = []
 
         self.rest = None
 
-        n = sum([p.size(0) for p in prompts])
-
-        with tqdm.tqdm(
-            total=nb,
-            dynamic_ncols=True,
-            desc="world generation",
-        ) as pbar:
+        n = sum([q.size(0) for q in quizzes])
+
+        if progress_bar:
+            with tqdm.tqdm(
+                total=nb,
+                dynamic_ncols=True,
+                desc="world generation",
+            ) as pbar:
+                while n < nb:
+                    q = self.queue.get(block=True)
+                    quizzes.append(q)
+                    n += q.size(0)
+                    pbar.update(q.size(0))
+        else:
             while n < nb:
-                p, s = self.queue.get(block=True)
-                prompts.append(p)
-                answers.append(s)
-                n += p.size(0)
-                pbar.update(p.size(0))
+                q = self.queue.get(block=True)
+                quizzes.append(q)
+                n += q.size(0)
 
-        prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
-        assert n == prompts.size(0)
+        quizzes = torch.cat(quizzes, dim=0)
+        assert n == quizzes.size(0)
 
         k = n - nb
 
         if k > 0:
-            rest = (prompts[-k:], answers[-k:])
-            prompts, answers = prompts[:-k], answers[:-k]
+            rest = quizzes[-k:]
+            quizzes = quizzes[:-k]
+
+        return quizzes
+
+    ######################################################################
+
+    def trivial_prompts_and_answers(self, prompts, answers):
+        pass
+
+    # The one to implement, returns two tensors nb x D and nb x D'
+    def generate_w_quizzes_(self, nb):
+        pass
+
+    # save a file to vizualize quizzes, you can save a txt or png file
+    def save_quiz_illustrations(
+        self,
+        result_dir,
+        filename_prefix,
+        prompts,
+        answers,
+        predicted_prompts=None,
+        predicted_answers=None,
+    ):
+        pass
+
+    def save_some_examples(self, result_dir):
+        pass
 
-        return prompts, answers
+    ######################################################################