X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=problem.py;h=0bc83a12148c5fec2af3ed7b96edfce3b86054d5;hb=15e704a200286551a8e9c1765a0340c370367dee;hp=8d973eb3d1dd7d917eeae4b47687d872fc4e6c8d;hpb=ff043757ea7d5d992a3d1fc4c435c1422997b1af;p=culture.git diff --git a/problem.py b/problem.py index 8d973eb..0bc83a1 100755 --- a/problem.py +++ b/problem.py @@ -7,15 +7,21 @@ class Problem: - # returns a nb x (L+1+L) long tensor where L is the length of one - # of the two states of a quizz - def generate_seq(self, nb): + def nb_token_values(self): pass - # save a file to vizualize quizzes, you can save a txt or png file - def save_quizzes(self, input, result_dir, filename_prefix, logger): + # returns two tensors nb x D and nb x D' + def generate_prompts_and_answers(self, nb): pass - # returns a pair (forward_tokens, backward_token) - def direction_tokens(self): + # save a file to vizualize quizzes, you can save a txt or png file + def save_quizzes( + self, + result_dir, + filename_prefix, + prompts, + answers, + predicted_prompt=None, + predicted_answers=None, + ): pass