Merge branch 'dev'
[culture.git] / quiz_machine.py
index 8ab5696..bc468d3 100755 (executable)
@@ -5,7 +5,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import math, os, tqdm, warnings
+import math, os, tqdm, warnings, sys
 
 import torch, torchvision
 
@@ -17,6 +17,36 @@ from mygpt import BracketedSequence
 
 import threading
 
+######################################################################
+# if output is log(P(X=y)) and target is Y, returns -log P(X=Y) + H(X
+# | X != Y)
+
+
+# output is NxCxT and target is NxT
+def confusion(output, target, reduction="mean"):
+    N, C, T = output.shape
+    output = output.permute(0, 2, 1).reshape(-1, C)
+    target = target.flatten()
+    all_t = torch.arange(N * T, device=output.device)
+    output = output.log_softmax(dim=-1)
+    result = -output[all_t, target]
+
+    output[all_t, target] = float("-inf")
+    output = output.log_softmax(dim=-1)
+    e = output.exp()
+    output[all_t, target] = 0
+    result = result - (output * e).sum(-1)
+
+    if reduction == "none":
+        return result.reshape(N, T)
+    elif reduction == "mean":
+        return result.reshape(N, T).mean()
+    elif reduction == "sum":
+        return result.reshape(N, T).sum()
+    else:
+        raise ValueError(f"unknown reduction '{reduction}'.")
+
+
 ######################################################################
 
 # ar_mask is a tensor with 0s and 1s, of same shape as input, with
@@ -241,7 +271,7 @@ class QuizMachine:
         self.train_c_quizzes = []
         self.test_c_quizzes = []
 
-    def save_quizzes(
+    def save_quiz_illustrations(
         self,
         result_dir,
         filename_prefix,
@@ -266,7 +296,7 @@ class QuizMachine:
             predicted_prompts *= 2
             predicted_answers *= 2
 
-        self.problem.save_quizzes(
+        self.problem.save_quiz_illustrations(
             result_dir,
             filename_prefix,
             quizzes[:, 1 : 1 + self.prompt_len],
@@ -368,16 +398,12 @@ class QuizMachine:
                 backward_nb_total = correct[n_backward].size(0)
 
                 self.logger(
-                    f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)"
-                )
-
-                self.logger(
-                    f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)"
+                    f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}"
                 )
 
             return result, correct
 
-        compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
+        compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
 
         test_result, test_correct = compute_accuracy(
             model.test_w_quizzes[:nmax], log_prefix="test"
@@ -388,7 +414,7 @@ class QuizMachine:
 
         ##############################
 
-        self.save_quizzes(
+        self.save_quiz_illustrations(
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
             quizzes=test_result[:72],
@@ -416,11 +442,21 @@ class QuizMachine:
             else:
                 self.test_c_quizzes.append(new_c_quizzes.to("cpu"))
 
+    def save_c_quizzes(self, filename):
+        torch.save((self.train_c_quizzes, self.test_c_quizzes), filename)
+
+    def load_c_quizzes(self, filename):
+        self.train_c_quizzes, self.test_c_quizzes = torch.load(filename)
+
     ######################################################################
 
-    def logproba_of_solutions(self, models, c_quizzes):
+    def solution_token_logprobas(self, models, c_quizzes):
         logproba = c_quizzes.new_zeros(
-            c_quizzes.size(0), len(models), device=self.device
+            c_quizzes.size(0),
+            len(models),
+            c_quizzes.size(1),
+            device=self.device,
+            dtype=torch.float32,
         )
 
         for model in models:
@@ -434,11 +470,12 @@ class QuizMachine:
                     input = input.to(self.device)
                     ar_mask = self.make_ar_mask(input)
                     output = model(mygpt.BracketedSequence(input)).x
-                    ce = (
-                        F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+                    l[:, model.id] = (
+                        -F.cross_entropy(
+                            output.transpose(1, 2), input, reduction="none"
+                        )
                         * ar_mask
                     )
-                    l[:, model.id] = -ce.sum(dim=-1)
 
                 model.train(t)