else:
self.queue = None
- def nb_token_values(self):
- pass
+ def nb_cached_quizzes(self):
+ if self.queue is None:
+ return None
+ else:
+ return self.queue.qsize() * self.chunk_size
+
+ def fill_cache(self):
+ while True:
+ quizzes = self.generate_w_quizzes_(self.chunk_size)
+ self.queue.put(quizzes.to("cpu"), block=True)
+
+ def generate_w_quizzes(self, nb, progress_bar=True):
+ if self.queue is None:
+ return self.generate_w_quizzes_(nb)
+
+ if self.rest is not None:
+ quizzes = rest
+ else:
+ quizzes = []
+
+ self.rest = None
+
+ n = sum([q.size(0) for q in quizzes])
+
+ if progress_bar:
+ with tqdm.tqdm(
+ total=nb, dynamic_ncols=True, desc="world generation", delay=10
+ ) as pbar:
+ while n < nb:
+ q = self.queue.get(block=True)
+ quizzes.append(q)
+ n += q.size(0)
+ pbar.update(q.size(0))
+ else:
+ while n < nb:
+ q = self.queue.get(block=True)
+ quizzes.append(q)
+ n += q.size(0)
+
+ quizzes = torch.cat(quizzes, dim=0)
+ assert n == quizzes.size(0)
+
+ k = n - nb
+
+ if k > 0:
+ rest = quizzes[-k:]
+ quizzes = quizzes[:-k]
+
+ return quizzes
+
+ ######################################################################
def trivial_prompts_and_answers(self, prompts, answers):
pass
# The one to implement, returns two tensors nb x D and nb x D'
- def generate_prompts_and_answers_(self, nb):
+ def generate_w_quizzes_(self, nb):
pass
# save a file to vizualize quizzes, you can save a txt or png file
- def save_quizzes(
+ def save_quiz_illustrations(
self,
result_dir,
filename_prefix,
):
pass
- def fill_cache(self):
- while True:
- prompts, answers = self.generate_prompts_and_answers_(self.chunk_size)
-
- self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
-
- def generate_prompts_and_answers(self, nb):
- if self.queue is None:
- return self.generate_prompts_and_answers_(nb)
-
- if self.rest is not None:
- prompts, answers = rest
- else:
- prompts, answers = [], []
-
- self.rest = None
-
- n = sum([p.size(0) for p in prompts])
-
- with tqdm.tqdm(
- total=nb,
- dynamic_ncols=True,
- desc="world generation",
- ) as pbar:
- while n < nb:
- p, s = self.queue.get(block=True)
- prompts.append(p)
- answers.append(s)
- n += p.size(0)
- pbar.update(p.size(0))
-
- prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
- assert n == prompts.size(0)
-
- k = n - nb
-
- if k > 0:
- rest = (prompts[-k:], answers[-k:])
- prompts, answers = prompts[:-k], answers[:-k]
+ def save_some_examples(self, result_dir):
+ pass
- return prompts, answers
+ ######################################################################