X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quizz_machine.py;h=d591d79c53de0f87944eda6d4f7b1527769d5190;hb=fcb71a73da3a27f81383e3000b9ad1ee8da45926;hp=88f2c1c3ed9805bc53d94910c3a405004ae74274;hpb=c32e471f8153ce4bdf19fb440f1e642eda4b972a;p=culture.git diff --git a/quizz_machine.py b/quizz_machine.py index 88f2c1c..d591d79 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -398,4 +398,16 @@ class QuizzMachine: device=self.device, ) + c_quizzes = self.reverse_time(c_quizzes) + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=ar_mask_solve, + seq_logproba=seq_logproba, + temperature=temperature, + deterministic_synthesis=True, + device=self.device, + ) + return c_quizzes, seq_logproba.mean()