Merge branch 'dev'
[culture.git] / problem.py
index 0795de1..05f3b20 100755 (executable)
@@ -5,17 +5,89 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
+import threading, queue, torch, tqdm
+
 
 class Problem:
 
 class Problem:
-    # returns a nb x (L+1+L) long tensor where L is the length of one
-    # of the two states of a quizz
-    def generate_token_sequences(self, nb):
+    def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1):
+        if nb_threads > 0:
+            self.chunk_size = chunk_size
+            self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
+            for _ in range(nb_threads):
+                threading.Thread(target=self.fill_cache, daemon=True).start()
+            self.rest = None
+        else:
+            self.queue = None
+
+    def nb_cached_quizzes(self):
+        if self.queue is None:
+            return None
+        else:
+            return self.queue.qsize() * self.chunk_size
+
+    def nb_token_values(self):
+        pass
+
+    def trivial_prompts_and_answers(self, prompts, answers):
+        pass
+
+    # The one to implement, returns two tensors nb x D and nb x D'
+    def generate_prompts_and_answers_(self, nb):
         pass
 
     # save a file to vizualize quizzes, you can save a txt or png file
         pass
 
     # save a file to vizualize quizzes, you can save a txt or png file
-    def save_quizzes(self, input, result_dir, filename_prefix):
+    def save_quiz_illustrations(
+        self,
+        result_dir,
+        filename_prefix,
+        prompts,
+        answers,
+        predicted_prompts=None,
+        predicted_answers=None,
+    ):
         pass
 
         pass
 
-    # returns a pair (forward_tokens, backward_token)
-    def direction_tokens(self):
+    def fill_cache(self):
+        while True:
+            prompts, answers = self.generate_prompts_and_answers_(self.chunk_size)
+
+            self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
+
+    def generate_prompts_and_answers(self, nb):
+        if self.queue is None:
+            return self.generate_prompts_and_answers_(nb)
+
+        if self.rest is not None:
+            prompts, answers = rest
+        else:
+            prompts, answers = [], []
+
+        self.rest = None
+
+        n = sum([p.size(0) for p in prompts])
+
+        with tqdm.tqdm(
+            total=nb,
+            dynamic_ncols=True,
+            desc="world generation",
+        ) as pbar:
+            while n < nb:
+                p, s = self.queue.get(block=True)
+                prompts.append(p)
+                answers.append(s)
+                n += p.size(0)
+                pbar.update(p.size(0))
+
+        prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
+        assert n == prompts.size(0)
+
+        k = n - nb
+
+        if k > 0:
+            rest = (prompts[-k:], answers[-k:])
+            prompts, answers = prompts[:-k], answers[:-k]
+
+        return prompts, answers
+
+    def save_some_examples(self, result_dir):
         pass
         pass