Update.
[culture.git] / problem.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import threading, queue, torch, tqdm
9
10
class Problem:
    """Placeholder definition of the methods a problem object exposes.

    Every method of this class is a no-op that returns None.
    MultiThreadProblem calls these four methods on the problem it wraps.
    """

    def nb_token_values(self):
        """No-op placeholder; returns None."""

    def trivial_prompts_and_answers(self, prompts, answers):
        """No-op placeholder; returns None."""

    # returns two tensors nb x D and nb x D'
    def generate_prompts_and_answers(self, nb):
        """No-op placeholder; returns None."""

    # save a file to visualize quizzes, you can save a txt or png file
    def save_quizzes(
        self,
        result_dir,
        filename_prefix,
        prompts,
        answers,
        predicted_prompts=None,
        predicted_answers=None,
    ):
        """No-op placeholder; returns None."""
33
34
class MultiThreadProblem:
    """Wrap a problem and pre-generate its samples in background threads.

    Worker threads repeatedly call ``problem.generate_prompts_and_answers``
    with ``chunk_size`` and push the resulting tensors, moved to CPU, into a
    bounded queue. ``generate_prompts_and_answers`` assembles a request from
    the queued chunks and keeps any surplus rows in ``self.rest`` so that
    they are served first on the next call.
    """

    def __init__(self, problem, max_nb_cached_chunks, chunk_size, nb_threads=1):
        """Store the wrapped problem and start the producer threads.

        Args:
            problem: object providing the methods this wrapper forwards to.
            max_nb_cached_chunks: maximum number of chunks held in the queue.
            chunk_size: value passed to the wrapped problem for each chunk.
            nb_threads: number of producer threads to start.
        """
        self.problem = problem
        self.chunk_size = chunk_size
        self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
        # Surplus (prompts, answers) rows left over from the previous
        # request, or None. Set before the workers start so the instance is
        # fully initialised by the time other threads are running.
        self.rest = None
        # Daemon threads, so that the endless producers do not keep the
        # interpreter alive at exit.
        for _ in range(nb_threads):
            threading.Thread(target=self.fill_cache, daemon=True).start()

    def nb_token_values(self):
        """Return ``nb_token_values()`` of the wrapped problem."""
        return self.problem.nb_token_values()

    def save_quizzes(
        self,
        result_dir,
        filename_prefix,
        prompts,
        answers,
        predicted_prompts=None,
        predicted_answers=None,
    ):
        """Forward every argument to ``save_quizzes`` of the wrapped problem."""
        # The predicted_* values given by the caller are passed through;
        # they must not be replaced by None.
        self.problem.save_quizzes(
            result_dir,
            filename_prefix,
            prompts,
            answers,
            predicted_prompts=predicted_prompts,
            predicted_answers=predicted_answers,
        )

    def fill_cache(self):
        """Producer loop run by each worker thread; never returns."""
        while True:
            prompts, answers = self.problem.generate_prompts_and_answers(
                self.chunk_size
            )

            # Blocks while the queue is full, which bounds the cache size.
            self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)

    def trivial_prompts_and_answers(self, prompts, answers):
        """Return ``trivial_prompts_and_answers(...)`` of the wrapped problem."""
        return self.problem.trivial_prompts_and_answers(prompts, answers)

    def generate_prompts_and_answers(self, nb):
        """Return ``nb`` prompts and answers taken from the cached chunks.

        Rows left over from the previous call are used first, then chunks
        are pulled from the queue (blocking until available) until at least
        ``nb`` rows are gathered. Rows beyond ``nb`` are stored in
        ``self.rest`` for the next call.

        Returns:
            A pair ``(prompts, answers)`` of tensors concatenated along
            dimension 0 and truncated to ``nb`` rows.
        """
        if self.rest is None:
            prompts, answers = [], []
        else:
            # The leftover is a pair of tensors; wrap each one in a list so
            # that further chunks can be appended before concatenation.
            prompts, answers = [self.rest[0]], [self.rest[1]]

        self.rest = None

        n = sum(p.size(0) for p in prompts)

        with tqdm.tqdm(
            total=nb,
            dynamic_ncols=True,
            desc="world generation",
        ) as pbar:
            while n < nb:
                p, a = self.queue.get(block=True)
                prompts.append(p)
                answers.append(a)
                n += p.size(0)
                pbar.update(p.size(0))

        prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
        assert n == prompts.size(0)

        k = n - nb

        if k > 0:
            # Keep the surplus rows on the instance for the next request.
            self.rest = (prompts[-k:], answers[-k:])
            prompts, answers = prompts[:-k], answers[:-k]

        return prompts, answers