Update.
[culture.git] / problem.py
index 7dd60dc..617b2a8 100755 (executable)
@@ -5,18 +5,34 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import threading, queue, torch
+import threading, queue, torch, tqdm
 
 
 class Problem:
 
 
 class Problem:
+    def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1):
+        if nb_threads > 0:
+            self.chunk_size = chunk_size
+            self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
+            for _ in range(nb_threads):
+                threading.Thread(target=self.fill_cache, daemon=True).start()
+            self.rest = None
+        else:
+            self.queue = None
+
+    def nb_cached_quizzes(self):
+        if self.queue is None:
+            return None
+        else:
+            return self.queue.qsize() * self.chunk_size
+
     def nb_token_values(self):
         pass
 
     def trivial_prompts_and_answers(self, prompts, answers):
         pass
 
     def nb_token_values(self):
         pass
 
     def trivial_prompts_and_answers(self, prompts, answers):
         pass
 
-    # returns two tensors nb x D and nb x D'
-    def generate_prompts_and_answers(self, nb):
+    # The one to implement, returns two tensors nb x D and nb x D'
+    def generate_prompts_and_answers_(self, nb):
         pass
 
     # save a file to vizualize quizzes, you can save a txt or png file
         pass
 
     # save a file to vizualize quizzes, you can save a txt or png file
@@ -31,48 +47,16 @@ class Problem:
     ):
         pass
 
     ):
         pass
 
-
-class MultiThreadProblem:
-    def __init__(self, problem, max_nb_cached_chunks, chunk_size):
-        self.problem = problem
-        self.chunk_size = chunk_size
-        self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
-        threading.Thread(target=self.fill_cache, daemon=True).start()
-        self.rest = None
-
-    def nb_token_values(self):
-        return self.problem.nb_token_values()
-
-    def save_quizzes(
-        self,
-        result_dir,
-        filename_prefix,
-        prompts,
-        answers,
-        predicted_prompts=None,
-        predicted_answers=None,
-    ):
-        self.problem.save_quizzes(
-            result_dir,
-            filename_prefix,
-            prompts,
-            answers,
-            predicted_prompts=None,
-            predicted_answers=None,
-        )
-
     def fill_cache(self):
         while True:
     def fill_cache(self):
         while True:
-            prompts, answers = self.problem.generate_prompts_and_answers(
-                self.chunk_size
-            )
+            prompts, answers = self.generate_prompts_and_answers_(self.chunk_size)
 
 
-            self.queue.put((prompts, answers), block=True)
-
-    def trivial_prompts_and_answers(self, prompts, answers):
-        return self.problem.trivial_prompts_and_answers(prompts, answers)
+            self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
 
     def generate_prompts_and_answers(self, nb):
 
     def generate_prompts_and_answers(self, nb):
+        if self.queue is None:
+            return self.generate_prompts_and_answers_(nb)
+
         if self.rest is not None:
             prompts, answers = rest
         else:
         if self.rest is not None:
             prompts, answers = rest
         else:
@@ -82,13 +66,20 @@ class MultiThreadProblem:
 
         n = sum([p.size(0) for p in prompts])
 
 
         n = sum([p.size(0) for p in prompts])
 
-        while n < nb:
-            p, s = self.queue.get(block=True)
-            prompts.append(p)
-            answers.append(s)
-            n += p.size(0)
+        with tqdm.tqdm(
+            total=nb,
+            dynamic_ncols=True,
+            desc="world generation",
+        ) as pbar:
+            while n < nb:
+                p, s = self.queue.get(block=True)
+                prompts.append(p)
+                answers.append(s)
+                n += p.size(0)
+                pbar.update(p.size(0))
 
         prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
 
         prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
+        assert n == prompts.size(0)
 
         k = n - nb
 
 
         k = n - nb