X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=quiz_machine.py;h=45b2247f50519d382bc5e04bbab5612fb60fe698;hb=719785dbea77989a54bf7592bb6919f2e8f3f6c5;hp=d4af77027d73c0017efe8cbecfaad8fcfc468cad;hpb=ceddc8cc3adbb045fdef1ccb0b3df2b8fed9eb4c;p=culture.git diff --git a/quiz_machine.py b/quiz_machine.py index d4af770..45b2247 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -50,7 +50,8 @@ def one_batch_masked_inplace_autoregression( t_next = dist.sample() all_n = torch.arange(t_next.size(0)) - seq_logproba += logits[all_n, t_next].sum(dim=-1) + + seq_logproba += logits[all_n, t_next] input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] @@ -359,11 +360,11 @@ class QuizMachine: backward_nb_total = correct[n_backward].size(0) self.logger( - f"{log_prefix}_forward_accuracy {n_epoch} {model.id=} {forward_nb_correct} / {forward_nb_total}" + f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)" ) self.logger( - f"{log_prefix}_backward_accuracy {n_epoch} {model.id=} {backward_nb_correct} / {backward_nb_total}" + f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)" ) return result, correct @@ -420,11 +421,11 @@ class QuizMachine: nb_correct = 0 + seq_logproba[...] = 0.0 + for model in models_for_validation: result = c_quizzes.clone() - seq_logproba[...] = 0.0 - ar_mask = self.make_ar_mask(result) masked_inplace_autoregression(