X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=problem.py;h=0bc83a12148c5fec2af3ed7b96edfce3b86054d5;hb=09c5eea203d5a2d8b1da84db0a336de151cf1c89;hp=354235ee4e6fc7b525694b63341137d5fefde475;hpb=3bb50442e264446ea94308aef86a4d2a7024417f;p=culture.git diff --git a/problem.py b/problem.py index 354235e..0bc83a1 100755 --- a/problem.py +++ b/problem.py @@ -7,14 +7,21 @@ class Problem: - # returns a nb x (L+1+L) long tensor - def generate_seq(self, nb): + def nb_token_values(self): pass - # save a file to vizualize quizzes, you can save a txt or png file - def save_quizzes(self, input, result_dir, filename_prefix, logger): + # returns two tensors nb x D and nb x D' + def generate_prompts_and_answers(self, nb): pass - # returns a pair (forward_tokens, backward_token) - def direction_tokens(self): + # save a file to vizualize quizzes, you can save a txt or png file + def save_quizzes( + self, + result_dir, + filename_prefix, + prompts, + answers, + predicted_prompt=None, + predicted_answers=None, + ): pass