return prompts.flatten(1), answers.flatten(1)
- def save_quizzes(
+ def save_quiz_illustrations(
self,
result_dir,
filename_prefix,
for t in self.all_tasks:
print(t.__name__)
prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
- self.save_quizzes(
+ self.save_quiz_illustrations(
result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
)
]:
print(t.__name__)
prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
- grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow)
+ grids.save_quiz_illustrations(
+ "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+ )
exit(0)
predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
- grids.save_quizzes(
+ grids.save_quiz_illustrations(
"/tmp",
"test",
prompts[:nb],