# Written by Francois Fleuret <francois@fleuret.org>
-import math, os, tqdm, warnings
+import math, os, tqdm, warnings, sys
import torch, torchvision
import mygpt
from mygpt import BracketedSequence
+import threading
+
+######################################################################
+# if output is log(P(X=y)) and target is Y, returns -log P(X=Y) + H(X
+# | X != Y)
+
+
+# output is NxCxT and target is NxT
+def confusion(output, target, reduction="mean"):
+ N, C, T = output.shape
+ output = output.permute(0, 2, 1).reshape(-1, C)
+ target = target.flatten()
+ all_t = torch.arange(N * T, device=output.device)
+ output = output.log_softmax(dim=-1)
+ result = -output[all_t, target]
+
+ output[all_t, target] = float("-inf")
+ output = output.log_softmax(dim=-1)
+ e = output.exp()
+ output[all_t, target] = 0
+ result = result - (output * e).sum(-1)
+
+ if reduction == "none":
+ return result.reshape(N, T)
+ elif reduction == "mean":
+ return result.reshape(N, T).mean()
+ elif reduction == "sum":
+ return result.reshape(N, T).sum()
+ else:
+ raise ValueError(f"unknown reduction '{reduction}'.")
+
+
######################################################################
# ar_mask is a tensor with 0s and 1s, of same shape as input, with
input,
ar_mask,
seq_logproba,
- temperature=1.0,
- deterministic_synthesis=False,
+ temperature,
+ deterministic_synthesis,
):
to_generate = (ar_mask.sum(0) > 0).nonzero()
t_next = dist.sample()
all_n = torch.arange(t_next.size(0))
- seq_logproba += logits[all_n, t_next].sum(dim=-1)
+
+ seq_logproba += logits[all_n, t_next]
input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
).all()
return i_forward, i_backward
+ def non_trivial(self, quizzes):
+ quizzes = quizzes.clone()
+ n_forward = quizzes[quizzes[:, 0] == self.token_forward]
+ n_backward = quizzes[:, 0] == self.token_backward
+ backward = quizzes[n_backward]
+ quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
+ return torch.logical_not(
+ self.problem.trivial_prompts_and_answers(
+ quizzes[:, 1 : 1 + self.prompt_len],
+ quizzes[:, 2 + self.prompt_len :],
+ )
+ )
+
def reverse_time(self, quizzes):
i_forward, i_backward = self.indices_forward_and_backward(quizzes)
self.prompt_len = None
self.answer_len = None
- self.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
- self.reverse_random_half_in_place(self.train_w_quizzes)
- self.train_w_quizzes = self.train_w_quizzes.to(device)
-
- self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)
- self.reverse_random_half_in_place(self.test_w_quizzes)
- self.test_w_quizzes = self.test_w_quizzes.to(device)
-
+ self.LOCK_C_QUIZZES = threading.Lock()
self.train_c_quizzes = []
self.test_c_quizzes = []
- if result_dir is not None:
- self.save_quizzes(
- result_dir,
- "culture_w_quizzes",
- self.train_w_quizzes[:72],
- )
-
- def save_quizzes(
+ def save_quiz_illustrations(
self,
result_dir,
filename_prefix,
quizzes,
mistakes=None,
):
- quizzes = quizzes.clone()
+ quizzes = quizzes.clone().to("cpu")
n_forward = quizzes[quizzes[:, 0] == self.token_forward]
n_backward = quizzes[:, 0] == self.token_backward
backward = quizzes[n_backward]
predicted_answers = 1 - predicted_prompts
if mistakes is not None:
# 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
- predicted_prompts *= mistakes
- predicted_answers *= mistakes
+ predicted_prompts *= mistakes.to("cpu")
+ predicted_answers *= mistakes.to("cpu")
else:
# 0/2 ~ not-to-predict / to predict
predicted_prompts *= 2
predicted_answers *= 2
- self.problem.save_quizzes(
+ self.problem.save_quiz_illustrations(
result_dir,
filename_prefix,
quizzes[:, 1 : 1 + self.prompt_len],
predicted_answers,
)
- def batches(self, split="train", desc=None):
- assert split in {"train", "test"}
- if split == "train":
- w_quizzes = self.train_w_quizzes
- c_quizzes = self.train_c_quizzes
- else:
- w_quizzes = self.test_w_quizzes
- c_quizzes = self.test_c_quizzes
+ def vocabulary_size(self):
+ return self.nb_token_values
- if len(c_quizzes) > 0:
- c_quizzes = torch.cat(c_quizzes, dim=0)
- if c_quizzes.size(0) > w_quizzes.size(0) // 2:
- i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
- c_quizzes = c_quizzes[i]
+ ######################################################################
- i = torch.randperm(w_quizzes.size(0))[
- : w_quizzes.size(0) - c_quizzes.size(0)
- ]
- w_quizzes = w_quizzes[i]
+ def batches(self, model, split="train", desc=None):
+ assert split in {"train", "test"}
+
+ with self.LOCK_C_QUIZZES:
+ if split == "train":
+ w_quizzes = model.train_w_quizzes
+ c_quizzes = self.train_c_quizzes
+ else:
+ w_quizzes = model.test_w_quizzes
+ c_quizzes = self.test_c_quizzes
+
+ if len(c_quizzes) > 0:
+ c_quizzes = torch.cat(c_quizzes, dim=0)
+ if c_quizzes.size(0) > w_quizzes.size(0) // 2:
+ i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
+ c_quizzes = c_quizzes[i]
+
+ i = torch.randperm(w_quizzes.size(0))[
+ : w_quizzes.size(0) - c_quizzes.size(0)
+ ]
+ w_quizzes = w_quizzes[i]
- self.nb_batch_w_quizzes = w_quizzes.size(0)
- self.nb_batch_c_quizzes = c_quizzes.size(0)
+ self.nb_batch_w_quizzes = w_quizzes.size(0)
+ self.nb_batch_c_quizzes = c_quizzes.size(0)
- input = torch.cat([w_quizzes, c_quizzes], dim=0)
- else:
- input = w_quizzes
- self.nb_batch_w_quizzes = w_quizzes.size(0)
- self.nb_batch_c_quizzes = 0
+ input = torch.cat([w_quizzes, c_quizzes], dim=0)
+ else:
+ input = w_quizzes
+ self.nb_batch_w_quizzes = w_quizzes.size(0)
+ self.nb_batch_c_quizzes = 0
# Shuffle
input = input[torch.randperm(input.size(0))]
):
yield batch
- def vocabulary_size(self):
- return self.nb_token_values
+ ######################################################################
def produce_results(
self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
):
def compute_accuracy(input, log_prefix=None):
+ input = input.to(self.device)
ar_mask = self.make_ar_mask(input)
result = input.clone() * (1 - ar_mask)
seq_logproba = torch.empty(input.size(0), device=self.device)
backward_nb_total = correct[n_backward].size(0)
self.logger(
- f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total}"
- )
-
- self.logger(
- f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total}"
+ f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}"
)
return result, correct
- compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train")
+ # compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
test_result, test_correct = compute_accuracy(
- self.test_w_quizzes[:nmax], log_prefix="test"
+ model.test_w_quizzes[:nmax], log_prefix="test"
)
main_test_accuracy = test_correct.sum() / test_correct.size(0)
##############################
- self.save_quizzes(
+ self.save_quiz_illustrations(
result_dir,
f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
quizzes=test_result[:72],
return main_test_accuracy
- def renew_w_quizzes(self, nb, for_train=True):
- input = self.train_w_quizzes if for_train else self.test_w_quizzes
+ ######################################################################
+
+ def renew_w_quizzes(self, model, nb, for_train=True):
+ input = model.train_w_quizzes if for_train else model.test_w_quizzes
nb = min(nb, input.size(0))
input[:-nb] = input[nb:].clone()
fresh_w_quizzes = self.generate_token_sequences(nb)
self.reverse_random_half_in_place(fresh_w_quizzes)
- input[-nb:] = fresh_w_quizzes.to(self.device)
+ input[-nb:] = fresh_w_quizzes.to("cpu")
+
+ ######################################################################
def store_c_quizzes(self, new_c_quizzes, for_train=True):
- if for_train:
- self.train_c_quizzes.append(new_c_quizzes)
- else:
- self.test_c_quizzes.append(new_c_quizzes)
+ with self.LOCK_C_QUIZZES:
+ if for_train:
+ self.train_c_quizzes.append(new_c_quizzes.to("cpu"))
+ else:
+ self.test_c_quizzes.append(new_c_quizzes.to("cpu"))
+
+ def save_c_quizzes(self, filename):
+ torch.save((self.train_c_quizzes, self.test_c_quizzes), filename)
+
+ def load_c_quizzes(self, filename):
+ self.train_c_quizzes, self.test_c_quizzes = torch.load(filename)
+
+ ######################################################################
+
+ def logproba_of_solutions(self, models, c_quizzes):
+ logproba = c_quizzes.new_zeros(
+ c_quizzes.size(0), len(models), device=self.device, dtype=torch.float32
+ )
+
+ for model in models:
+ with torch.autograd.no_grad():
+ t = model.training
+ model.eval()
+
+ for input, l in zip(
+ c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
+ ):
+ input = input.to(self.device)
+ ar_mask = self.make_ar_mask(input)
+ output = model(mygpt.BracketedSequence(input)).x
+ ce = (
+ F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+ * ar_mask
+ )
+ l[:, model.id] = -ce.sum(dim=-1)
+
+ model.train(t)
+
+ return logproba.to("cpu")
+
+ ###############################################################
def compute_correctness(
self,
nb_correct = 0
+ seq_logproba[...] = 0.0
+
for model in models_for_validation:
result = c_quizzes.clone()
- seq_logproba[...] = 0.0
-
ar_mask = self.make_ar_mask(result)
masked_inplace_autoregression(
def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
c_quizzes = torch.empty(
- nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
+ nb,
+ self.prompt_len + self.answer_len + 2,
+ device=self.device,
+ dtype=torch.int64,
)
seq_logproba = torch.zeros(nb, device=self.device)
device=self.device,
)
- return c_quizzes
+ return c_quizzes.to("cpu")