X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quiz_machine.py;h=c49ecf2b71f2508dd01f6075de225e29153be387;hb=5b7022591f48382ec84b1dda17297b1ed15166d7;hp=88fd9f1dfea42ebfdbd99f290243b66d45c8f639;hpb=7b716a85786247b292ee71a635c98a18c66b421d;p=culture.git diff --git a/quiz_machine.py b/quiz_machine.py index 88fd9f1..c49ecf2 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -5,7 +5,7 @@ # Written by Francois Fleuret -import math, os, tqdm, warnings +import math, os, tqdm, warnings, sys import torch, torchvision @@ -17,6 +17,36 @@ from mygpt import BracketedSequence import threading +###################################################################### +# if output is log(P(X=y)) and target is Y, returns -log P(X=Y) + H(X +# | X != Y) + + +# output is NxCxT and target is NxT +def confusion(output, target, reduction="mean"): + N, C, T = output.shape + output = output.permute(0, 2, 1).reshape(-1, C) + target = target.flatten() + all_t = torch.arange(N * T, device=output.device) + output = output.log_softmax(dim=-1) + result = -output[all_t, target] + + output[all_t, target] = float("-inf") + output = output.log_softmax(dim=-1) + e = output.exp() + output[all_t, target] = 0 + result = result - (output * e).sum(-1) + + if reduction == "none": + return result.reshape(N, T) + elif reduction == "mean": + return result.reshape(N, T).mean() + elif reduction == "sum": + return result.reshape(N, T).sum() + else: + raise ValueError(f"unknown reduction '{reduction}'.") + + ###################################################################### # ar_mask is a tensor with 0s and 1s, of same shape as input, with @@ -373,7 +403,7 @@ class QuizMachine: return result, correct - compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train") + # compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train") test_result, test_correct = compute_accuracy( model.test_w_quizzes[:nmax], log_prefix="test" @@ -412,6 +442,12 @@ class QuizMachine: else: self.test_c_quizzes.append(new_c_quizzes.to("cpu")) + def save_c_quizzes(self, filename): + torch.save((self.train_c_quizzes, self.test_c_quizzes), filename) + + def load_c_quizzes(self, filename): + self.train_c_quizzes, self.test_c_quizzes = torch.load(filename) + ###################################################################### def logproba_of_solutions(self, models, c_quizzes):