Update.
[culture.git] / problem.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import threading, queue, torch
9
10
class Problem:
    """Interface that concrete problems are expected to override.

    Every method of this base class is a no-op returning ``None``.
    """

    def nb_token_values(self):
        """Placeholder; the base implementation returns ``None``."""
        return None

    def trivial_prompts_and_answers(self, prompts, answers):
        """Placeholder; the base implementation returns ``None``."""
        return None

    def generate_prompts_and_answers(self, nb):
        """Placeholder for the generator of ``nb`` prompt/answer pairs.

        Implementations return two tensors, of size nb x D and nb x D'.
        The base implementation returns ``None``.
        """
        return None

    def save_quizzes(
        self,
        result_dir,
        filename_prefix,
        prompts,
        answers,
        predicted_prompts=None,
        predicted_answers=None,
    ):
        """Placeholder for saving a file that visualizes quizzes.

        Implementations may save a txt or a png file. The base
        implementation does nothing and returns ``None``.
        """
        return None
33
34
class MultiThreadProblem:
    """Wrap a problem and pre-generate its samples in a background thread.

    A daemon thread repeatedly asks the wrapped problem for chunks of
    ``chunk_size`` samples and stores them in a bounded queue.
    ``generate_prompts_and_answers`` then assembles batches of any size
    from these chunks, keeping the surplus for the next call.
    """

    def __init__(self, problem, max_nb_cached_chunks, chunk_size):
        """Store the settings and start the background generation thread.

        Args:
            problem: wrapped object; its ``nb_token_values``,
                ``save_quizzes``, ``trivial_prompts_and_answers`` and
                ``generate_prompts_and_answers`` methods are delegated to.
            max_nb_cached_chunks: maximum size of the chunk queue.
            chunk_size: number of samples requested per generated chunk.
        """
        self.problem = problem
        self.chunk_size = chunk_size
        self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
        # Samples fetched from the queue but not yet handed to a caller,
        # stored as a (prompts, answers) pair of tensors, or None.
        self.rest = None
        # Start the worker only once the instance state is fully set up.
        threading.Thread(target=self.fill_cache, daemon=True).start()

    def nb_token_values(self):
        """Return the wrapped problem's ``nb_token_values()``."""
        return self.problem.nb_token_values()

    def save_quizzes(
        self,
        result_dir,
        filename_prefix,
        prompts,
        answers,
        predicted_prompts=None,
        predicted_answers=None,
    ):
        """Forward all the arguments to the wrapped problem's ``save_quizzes``."""
        # The predictions must be passed through, not replaced with None.
        self.problem.save_quizzes(
            result_dir,
            filename_prefix,
            prompts,
            answers,
            predicted_prompts=predicted_prompts,
            predicted_answers=predicted_answers,
        )

    def fill_cache(self):
        """Generate chunks forever and push them into the queue.

        This is the target of the daemon thread started in ``__init__``.
        ``put`` blocks while the queue is full.
        """
        while True:
            prompts, answers = self.problem.generate_prompts_and_answers(
                self.chunk_size
            )

            self.queue.put((prompts, answers), block=True)

    def trivial_prompts_and_answers(self, prompts, answers):
        """Return the wrapped problem's ``trivial_prompts_and_answers``."""
        return self.problem.trivial_prompts_and_answers(prompts, answers)

    def generate_prompts_and_answers(self, nb):
        """Return ``nb`` prompt/answer pairs assembled from cached chunks.

        Leftover samples from the previous call are used first, then
        chunks are taken from the queue (blocking until available) until
        at least ``nb`` samples are gathered. Any surplus is stored in
        ``self.rest`` for the next call.

        Args:
            nb: number of samples requested.

        Returns:
            A pair of tensors (prompts, answers), each with ``nb`` rows.
        """
        if self.rest is not None:
            # The leftover is a pair of tensors, we need lists of tensors.
            prompts, answers = [self.rest[0]], [self.rest[1]]
        else:
            prompts, answers = [], []

        self.rest = None

        n = sum(p.size(0) for p in prompts)

        # Always fetch at least one chunk, so that torch.cat gets a
        # non-empty list even when nb <= 0 and there is no leftover.
        while n < nb or not prompts:
            p, s = self.queue.get(block=True)
            prompts.append(p)
            answers.append(s)
            n += p.size(0)

        prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)

        k = n - nb

        if k > 0:
            # Keep the surplus samples for the next call.
            self.rest = (prompts[-k:], answers[-k:])
            prompts, answers = prompts[:-k], answers[:-k]

        return prompts, answers