Update.
[culture.git] / problem.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import threading, queue, torch, tqdm
9
10
class Problem:
    """Base class for quiz generators, with an optional background chunk cache.

    Subclasses implement ``generate_prompts_and_answers_``. When ``nb_threads``
    is positive, daemon threads call it repeatedly with ``chunk_size`` and push
    the results (moved to CPU) into a bounded queue;
    ``generate_prompts_and_answers`` then assembles its result from those
    cached chunks instead of generating synchronously.
    """

    def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1):
        """Set up the generator and, if requested, start the caching threads.

        Args:
            max_nb_cached_chunks: passed as ``maxsize`` to ``queue.Queue``;
                only used when ``nb_threads > 0`` (an int is expected then).
            chunk_size: number of samples requested from
                ``generate_prompts_and_answers_`` per cached chunk; only used
                when ``nb_threads > 0``.
            nb_threads: number of daemon producer threads. Any value <= 0
                disables caching entirely.
        """
        self.chunk_size = chunk_size
        # Rows already pulled from the queue but not yet handed to a caller,
        # stored as a (prompts, answers) pair of tensors, or None.
        self.rest = None

        if nb_threads > 0:
            self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
            for _ in range(nb_threads):
                threading.Thread(target=self.fill_cache, daemon=True).start()
        else:
            self.queue = None

    def nb_token_values(self):
        """Hook for subclasses; the base implementation returns None."""
        pass

    def trivial_prompts_and_answers(self, prompts, answers):
        """Hook for subclasses; the base implementation returns None."""
        pass

    # The one to implement, returns two tensors nb x D and nb x D'
    def generate_prompts_and_answers_(self, nb):
        """Generate ``nb`` samples; must be overridden by subclasses."""
        pass

    # save a file to visualize quizzes, you can save a txt or png file
    def save_quizzes(
        self,
        result_dir,
        filename_prefix,
        prompts,
        answers,
        predicted_prompts=None,
        predicted_answers=None,
    ):
        """Hook for subclasses; the base implementation does nothing."""
        pass

    def fill_cache(self):
        """Producer loop run by each caching thread; never returns.

        ``put`` blocks while the queue is full, which throttles generation.
        """
        while True:
            prompts, answers = self.generate_prompts_and_answers_(self.chunk_size)

            self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)

    def generate_prompts_and_answers(self, nb):
        """Return ``nb`` prompts and the matching ``nb`` answers.

        Without a cache this simply delegates to
        ``generate_prompts_and_answers_``. With a cache, chunks are taken from
        the queue until at least ``nb`` rows are available; surplus rows are
        kept in ``self.rest`` and served first by the next call, so no
        generated sample is thrown away.

        Returns:
            A ``(prompts, answers)`` pair of tensors with ``nb`` rows each.
        """
        if self.queue is None:
            return self.generate_prompts_and_answers_(nb)

        if self.rest is not None:
            # The leftover is a pair of tensors; wrap each one in a list so
            # further chunks can be appended before the concatenation.
            rest_prompts, rest_answers = self.rest
            prompts, answers = [rest_prompts], [rest_answers]
        else:
            prompts, answers = [], []

        self.rest = None

        n = sum(p.size(0) for p in prompts)

        with tqdm.tqdm(
            total=nb,
            dynamic_ncols=True,
            desc="world generation",
        ) as pbar:
            if n > 0:
                # Account for the rows carried over from the previous call.
                pbar.update(min(n, nb))

            while n < nb:
                p, s = self.queue.get(block=True)
                prompts.append(p)
                answers.append(s)
                n += p.size(0)
                pbar.update(p.size(0))

        prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
        assert n == prompts.size(0)

        # Number of surplus rows beyond what the caller asked for.
        k = n - nb

        if k > 0:
            self.rest = (prompts[-k:], answers[-k:])
            prompts, answers = prompts[:-k], answers[:-k]

        return prompts, answers