X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=problem.py;h=0bc83a12148c5fec2af3ed7b96edfce3b86054d5;hb=15e704a200286551a8e9c1765a0340c370367dee;hp=95a9c4180be0bbac89bd19b8329b80c22942a72f;hpb=4f0057b363762698f90eea05de154e62b6883bd0;p=culture.git diff --git a/problem.py b/problem.py index 95a9c41..0bc83a1 100755 --- a/problem.py +++ b/problem.py @@ -7,15 +7,21 @@ class Problem: - # returns a nb x (L+1+L) long tensor where L is the length of one - # of the two states of a quizz - def generate_seq(self, nb): + def nb_token_values(self): pass - # save a file to vizualize quizzes, you can save a txt or png file - def save_quizzes(self, input, result_dir, filename_prefix): + # returns two tensors nb x D and nb x D' + def generate_prompts_and_answers(self, nb): pass - # returns a pair (forward_tokens, backward_token) - def direction_tokens(self): + # save a file to vizualize quizzes, you can save a txt or png file + def save_quizzes( + self, + result_dir, + filename_prefix, + prompts, + answers, + predicted_prompt=None, + predicted_answers=None, + ): pass