Update.
[culture.git] / problem.py
index a49634d..7eeb6b4 100755 (executable)
@@ -9,14 +9,30 @@ import threading, queue, torch, tqdm
 
 
 class Problem:
+    def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1):
+        if nb_threads > 0:
+            self.chunk_size = chunk_size
+            self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
+            for _ in range(nb_threads):
+                threading.Thread(target=self.fill_cache, daemon=True).start()
+            self.rest = None
+        else:
+            self.queue = None
+
+    def nb_cached_quizzes(self):
+        if self.queue is None:
+            return None
+        else:
+            return self.queue.qsize() * self.chunk_size
+
     def nb_token_values(self):
         pass
 
     def trivial_prompts_and_answers(self, prompts, answers):
         pass
 
-    # returns two tensors nb x D and nb x D'
-    def generate_prompts_and_answers(self, nb):
+    # The method to implement: returns two tensors of size nb x D and nb x D'
+    def generate_prompts_and_answers_(self, nb):
         pass
 
     # save a file to visualize quizzes; you can save a txt or png file
@@ -31,49 +47,16 @@ class Problem:
     ):
         pass
 
-
-class MultiThreadProblem:
-    def __init__(self, problem, max_nb_cached_chunks, chunk_size, nb_threads=1):
-        self.problem = problem
-        self.chunk_size = chunk_size
-        self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
-        for _ in range(nb_threads):
-            threading.Thread(target=self.fill_cache, daemon=True).start()
-        self.rest = None
-
-    def nb_token_values(self):
-        return self.problem.nb_token_values()
-
-    def save_quizzes(
-        self,
-        result_dir,
-        filename_prefix,
-        prompts,
-        answers,
-        predicted_prompts=None,
-        predicted_answers=None,
-    ):
-        self.problem.save_quizzes(
-            result_dir,
-            filename_prefix,
-            prompts,
-            answers,
-            predicted_prompts=None,
-            predicted_answers=None,
-        )
-
     def fill_cache(self):
         while True:
-            prompts, answers = self.problem.generate_prompts_and_answers(
-                self.chunk_size
-            )
+            prompts, answers = self.generate_prompts_and_answers_(self.chunk_size)
 
             self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
 
-    def trivial_prompts_and_answers(self, prompts, answers):
-        return self.problem.trivial_prompts_and_answers(prompts, answers)
-
     def generate_prompts_and_answers(self, nb):
+        if self.queue is None:
+            return self.generate_prompts_and_answers_(nb)
+
         if self.rest is not None:
             prompts, answers = rest
         else:
@@ -105,3 +88,6 @@ class MultiThreadProblem:
             prompts, answers = prompts[:-k], answers[:-k]
 
         return prompts, answers
+
+    def save_some_examples(self, result_dir):
+        pass