X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quiz_machine.py;h=c49ecf2b71f2508dd01f6075de225e29153be387;hb=5b7022591f48382ec84b1dda17297b1ed15166d7;hp=4f704a0587ba5ffa282e1867f49c3535c7c152a1;hpb=12c775dcbd3d3cd703f35c181faa6d2a680a0450;p=culture.git diff --git a/quiz_machine.py b/quiz_machine.py index 4f704a0..c49ecf2 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -5,7 +5,7 @@ # Written by Francois Fleuret -import math, os, tqdm, warnings +import math, os, tqdm, warnings, sys import torch, torchvision @@ -17,6 +17,36 @@ from mygpt import BracketedSequence import threading +###################################################################### +# if output is log(P(X=y)) and target is Y, returns -log P(X=Y) + H(X +# | X != Y) + + +# output is NxCxT and target is NxT +def confusion(output, target, reduction="mean"): + N, C, T = output.shape + output = output.permute(0, 2, 1).reshape(-1, C) + target = target.flatten() + all_t = torch.arange(N * T, device=output.device) + output = output.log_softmax(dim=-1) + result = -output[all_t, target] + + output[all_t, target] = float("-inf") + output = output.log_softmax(dim=-1) + e = output.exp() + output[all_t, target] = 0 + result = result - (output * e).sum(-1) + + if reduction == "none": + return result.reshape(N, T) + elif reduction == "mean": + return result.reshape(N, T).mean() + elif reduction == "sum": + return result.reshape(N, T).sum() + else: + raise ValueError(f"unknown reduction '{reduction}'.") + + ###################################################################### # ar_mask is a tensor with 0s and 1s, of same shape as input, with @@ -241,7 +271,7 @@ class QuizMachine: self.train_c_quizzes = [] self.test_c_quizzes = [] - def save_quizzes( + def save_quiz_illustrations( self, result_dir, filename_prefix, @@ -266,7 +296,7 @@ class QuizMachine: predicted_prompts *= 2 predicted_answers *= 2 - self.problem.save_quizzes( + self.problem.save_quiz_illustrations( result_dir, filename_prefix, quizzes[:, 1 : 1 + self.prompt_len], @@ -373,7 +403,7 @@ class QuizMachine: return result, correct - compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train") + # compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train") test_result, test_correct = compute_accuracy( model.test_w_quizzes[:nmax], log_prefix="test" @@ -384,7 +414,7 @@ class QuizMachine: ############################## - self.save_quizzes( + self.save_quiz_illustrations( result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}", quizzes=test_result[:72], @@ -412,11 +442,17 @@ class QuizMachine: else: self.test_c_quizzes.append(new_c_quizzes.to("cpu")) + def save_c_quizzes(self, filename): + torch.save((self.train_c_quizzes, self.test_c_quizzes), filename) + + def load_c_quizzes(self, filename): + self.train_c_quizzes, self.test_c_quizzes = torch.load(filename) + ###################################################################### def logproba_of_solutions(self, models, c_quizzes): logproba = c_quizzes.new_zeros( - c_quizzes.size(0), len(models), device=self.device + c_quizzes.size(0), len(models), device=self.device, dtype=torch.float32 ) for model in models: