Merge branch 'dev'
[culture.git] / problem.py
index 617b2a8..9bee5b2 100755 (executable)
@@ -25,18 +25,63 @@ class Problem:
         else:
             return self.queue.qsize() * self.chunk_size
 
-    def nb_token_values(self):
-        pass
+    def fill_cache(self):
+        while True:
+            quizzes = self.generate_w_quizzes_(self.chunk_size)
+            self.queue.put(quizzes.to("cpu"), block=True)
+
+    def generate_w_quizzes(self, nb, progress_bar=True):
+        if self.queue is None:
+            return self.generate_w_quizzes_(nb)
+
+        if self.rest is not None:
+            quizzes = rest
+        else:
+            quizzes = []
+
+        self.rest = None
+
+        n = sum([q.size(0) for q in quizzes])
+
+        if progress_bar:
+            with tqdm.tqdm(
+                total=nb,
+                dynamic_ncols=True,
+                desc="world generation",
+            ) as pbar:
+                while n < nb:
+                    q = self.queue.get(block=True)
+                    quizzes.append(q)
+                    n += q.size(0)
+                    pbar.update(q.size(0))
+        else:
+            while n < nb:
+                q = self.queue.get(block=True)
+                quizzes.append(q)
+                n += q.size(0)
+
+        quizzes = torch.cat(quizzes, dim=0)
+        assert n == quizzes.size(0)
+
+        k = n - nb
+
+        if k > 0:
+            rest = quizzes[-k:]
+            quizzes = quizzes[:-k]
+
+        return quizzes
+
+    ######################################################################
 
     def trivial_prompts_and_answers(self, prompts, answers):
         pass
 
     # The one to implement, returns two tensors nb x D and nb x D'
-    def generate_prompts_and_answers_(self, nb):
+    def generate_w_quizzes_(self, nb):
         pass
 
     # save a file to vizualize quizzes, you can save a txt or png file
-    def save_quizzes(
+    def save_quiz_illustrations(
         self,
         result_dir,
         filename_prefix,
@@ -47,44 +92,7 @@ class Problem:
     ):
         pass
 
-    def fill_cache(self):
-        while True:
-            prompts, answers = self.generate_prompts_and_answers_(self.chunk_size)
-
-            self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
-
-    def generate_prompts_and_answers(self, nb):
-        if self.queue is None:
-            return self.generate_prompts_and_answers_(nb)
-
-        if self.rest is not None:
-            prompts, answers = rest
-        else:
-            prompts, answers = [], []
-
-        self.rest = None
-
-        n = sum([p.size(0) for p in prompts])
-
-        with tqdm.tqdm(
-            total=nb,
-            dynamic_ncols=True,
-            desc="world generation",
-        ) as pbar:
-            while n < nb:
-                p, s = self.queue.get(block=True)
-                prompts.append(p)
-                answers.append(s)
-                n += p.size(0)
-                pbar.update(p.size(0))
-
-        prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
-        assert n == prompts.size(0)
-
-        k = n - nb
-
-        if k > 0:
-            rest = (prompts[-k:], answers[-k:])
-            prompts, answers = prompts[:-k], answers[:-k]
+    def save_some_examples(self, result_dir):
+        pass
 
-        return prompts, answers
+    ######################################################################