05f3b20984165acab83fd3ae563e6be896a14d6a
[culture.git] / problem.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import threading, queue, torch, tqdm
9
10
class Problem:
    """Base class for quiz generators with an optional background cache.

    Subclasses implement ``generate_prompts_and_answers_``. When
    ``nb_threads > 0``, that many daemon threads call it repeatedly with
    ``chunk_size`` and push the resulting (prompts, answers) chunks into a
    queue holding at most ``max_nb_cached_chunks`` chunks;
    ``generate_prompts_and_answers`` then draws quizzes from that queue.
    """

    def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1):
        """Set up the optional chunk cache and start its worker threads.

        Args:
            max_nb_cached_chunks: maximum number of chunks held in the queue
                (``queue.Queue`` semantics: 0 or less means unbounded).
                Required when ``nb_threads > 0``.
            chunk_size: number of quizzes requested per worker call. Required,
                and must be positive, when ``nb_threads > 0``.
            nb_threads: number of background generator threads; 0 or less
                disables the cache and quizzes are generated on demand.

        Raises:
            ValueError: if ``nb_threads > 0`` and ``max_nb_cached_chunks`` is
                None, or ``chunk_size`` is None or not positive.
        """
        # Surplus ([prompts], [answers]) chunk lists left over by the previous
        # call to generate_prompts_and_answers. Set before any worker starts.
        self.rest = None

        if nb_threads > 0:
            # queue.Queue(maxsize=None) makes every put() raise TypeError in
            # the worker thread, after which the consumer would block forever;
            # a chunk_size <= 0 adds no rows, so the consumer would spin forever.
            if max_nb_cached_chunks is None:
                raise ValueError(
                    "max_nb_cached_chunks is required when nb_threads > 0"
                )
            if chunk_size is None or chunk_size <= 0:
                raise ValueError(
                    "chunk_size must be a positive integer when nb_threads > 0"
                )

            self.chunk_size = chunk_size
            self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
            # NOTE(review): workers start here and immediately call
            # generate_prompts_and_answers_, so a subclass should set whatever
            # state that method needs before calling this __init__.
            for _ in range(nb_threads):
                threading.Thread(target=self.fill_cache, daemon=True).start()
        else:
            self.queue = None

    def nb_cached_quizzes(self):
        """Return the approximate number of quizzes waiting in the cache.

        Returns None when there is no background cache. Otherwise returns
        ``qsize() * chunk_size``; ``Queue.qsize`` is only approximate while
        workers run, and the surplus held in ``self.rest`` is not counted.
        """
        if self.queue is None:
            return None
        else:
            return self.queue.qsize() * self.chunk_size

    def nb_token_values(self):
        """Hook for subclasses; the base implementation returns None.

        Presumably the size of the token vocabulary — confirm in subclasses.
        """
        pass

    def trivial_prompts_and_answers(self, prompts, answers):
        """Hook for subclasses; the base implementation returns None."""
        pass

    # The one to implement, returns two tensors nb x D and nb x D'
    def generate_prompts_and_answers_(self, nb):
        """Generate ``nb`` fresh quizzes; must be overridden by subclasses."""
        pass

    # Save a file to visualize quizzes; it can be a txt or a png file.
    def save_quiz_illustrations(
        self,
        result_dir,
        filename_prefix,
        prompts,
        answers,
        predicted_prompts=None,
        predicted_answers=None,
    ):
        """Hook for subclasses; the base implementation does nothing."""
        pass

    def fill_cache(self):
        """Worker loop: generate chunks forever and push them into the queue.

        Each chunk is moved to the CPU before being queued; ``put`` blocks
        while the queue is full.
        """
        while True:
            prompts, answers = self.generate_prompts_and_answers_(self.chunk_size)

            self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)

    def generate_prompts_and_answers(self, nb):
        """Return ``nb`` quizzes as a pair of tensors ``(prompts, answers)``.

        Without a cache this simply calls ``generate_prompts_and_answers_``.
        With a cache it first reuses the surplus left by the previous call,
        then blocks on the queue until at least ``nb`` quizzes are available,
        returns exactly ``nb`` of them, and keeps any surplus in ``self.rest``
        for the next call.
        """
        if self.queue is None:
            return self.generate_prompts_and_answers_(nb)

        # self.rest stores lists of tensor chunks so they can be extended.
        if self.rest is not None:
            prompts, answers = self.rest
        else:
            prompts, answers = [], []

        self.rest = None

        n = sum(p.size(0) for p in prompts)

        # Only show a progress bar when we actually have to wait for workers.
        if n < nb:
            with tqdm.tqdm(
                total=nb,
                initial=n,
                dynamic_ncols=True,
                desc="world generation",
            ) as pbar:
                while n < nb:
                    p, s = self.queue.get(block=True)
                    prompts.append(p)
                    answers.append(s)
                    n += p.size(0)
                    pbar.update(p.size(0))

        prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
        assert n == prompts.size(0)

        k = n - nb

        if k > 0:
            # Keep the surplus for the next call instead of discarding it.
            self.rest = ([prompts[-k:]], [answers[-k:]])
            prompts, answers = prompts[:-k], answers[:-k]

        return prompts, answers

    def save_some_examples(self, result_dir):
        """Hook for subclasses; the base implementation does nothing."""
        pass