import torch
from torch.nn import functional as F

import mygpt
from mygpt import BracketedSequence
+import threading
+
######################################################################
# ar_mask is a tensor with 0s and 1s, of same shape as input, with
# 1s where tokens should be generated; the other tokens are kept
# unchanged.
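+#
+# A made-up illustration of the convention (values assumed for the
+# example, not taken from this file): with a three-token prompt and a
+# two-token answer,
+#
+#   input   = [[11, 7, 3, 9, 4]]
+#   ar_mask = [[ 0, 0, 0, 1, 1]]
+#
+# only the last two positions are regenerated, the prompt is kept.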
self.prompt_len = None
self.answer_len = None
- # self.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
- # self.reverse_random_half_in_place(self.train_w_quizzes)
-
- # self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)
- # self.reverse_random_half_in_place(self.test_w_quizzes)
-
+ self.LOCK_C_QUIZZES = threading.Lock()
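+ # This lock guards train_c_quizzes and test_c_quizzes, which are
+ # appended to by store_c_quizzes() and read by batches(), possibly
+ # from different threads.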
self.train_c_quizzes = []
self.test_c_quizzes = []
- # if result_dir is not None:
- # self.save_quizzes(
- # result_dir,
- # "culture_w_quizzes",
- # self.train_w_quizzes[:72],
- # )
-
def save_quizzes(
self,
result_dir,
predicted_answers,
)
+ def vocabulary_size(self):
+ return self.nb_token_values
+
+ ######################################################################
+
def batches(self, model, split="train", desc=None):
assert split in {"train", "test"}
- if split == "train":
- w_quizzes = model.train_w_quizzes
- c_quizzes = self.train_c_quizzes
- else:
- w_quizzes = model.test_w_quizzes
- c_quizzes = self.test_c_quizzes
- if len(c_quizzes) > 0:
- c_quizzes = torch.cat(c_quizzes, dim=0)
- if c_quizzes.size(0) > w_quizzes.size(0) // 2:
- i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
- c_quizzes = c_quizzes[i]
-
- i = torch.randperm(w_quizzes.size(0))[
- : w_quizzes.size(0) - c_quizzes.size(0)
- ]
- w_quizzes = w_quizzes[i]
+ with self.LOCK_C_QUIZZES:
+ if split == "train":
+ w_quizzes = model.train_w_quizzes
+ c_quizzes = self.train_c_quizzes
+ else:
+ w_quizzes = model.test_w_quizzes
+ c_quizzes = self.test_c_quizzes
+
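+ # Mix the stored c_quizzes into the stream, capped at half of the
+ # total: e.g. with 1000 w_quizzes and 700 stored c_quizzes, 500
+ # random c_quizzes are kept and 500 random w_quizzes complete the
+ # epoch.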
+ if len(c_quizzes) > 0:
+ c_quizzes = torch.cat(c_quizzes, dim=0)
+ if c_quizzes.size(0) > w_quizzes.size(0) // 2:
+ i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
+ c_quizzes = c_quizzes[i]
+
+ i = torch.randperm(w_quizzes.size(0))[
+ : w_quizzes.size(0) - c_quizzes.size(0)
+ ]
+ w_quizzes = w_quizzes[i]
- self.nb_batch_w_quizzes = w_quizzes.size(0)
- self.nb_batch_c_quizzes = c_quizzes.size(0)
+ self.nb_batch_w_quizzes = w_quizzes.size(0)
+ self.nb_batch_c_quizzes = c_quizzes.size(0)
- input = torch.cat([w_quizzes, c_quizzes], dim=0)
- else:
- input = w_quizzes
- self.nb_batch_w_quizzes = w_quizzes.size(0)
- self.nb_batch_c_quizzes = 0
+ input = torch.cat([w_quizzes, c_quizzes], dim=0)
+ else:
+ input = w_quizzes
+ self.nb_batch_w_quizzes = w_quizzes.size(0)
+ self.nb_batch_c_quizzes = 0
# Shuffle
input = input[torch.randperm(input.size(0))]
):
yield batch
- def vocabulary_size(self):
- return self.nb_token_values
+ ######################################################################
def produce_results(
self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
):
def compute_accuracy(input, log_prefix=None):
+ input = input.to(self.device)
ar_mask = self.make_ar_mask(input)
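+ # Blank out the tokens flagged by ar_mask; they are regenerated
+ # autoregressively and then compared against the ground truth.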
result = input.clone() * (1 - ar_mask)
seq_logproba = torch.empty(input.size(0), device=self.device)
return main_test_accuracy
+ ######################################################################
+
def renew_w_quizzes(self, model, nb, for_train=True):
input = model.train_w_quizzes if for_train else model.test_w_quizzes
nb = min(nb, input.size(0))
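+ # Shift the kept quizzes toward the front of the buffer and
+ # overwrite the last nb slots with freshly generated quizzes,
+ # a random half of which are reversed.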
input[:-nb] = input[nb:].clone()
fresh_w_quizzes = self.generate_token_sequences(nb)
self.reverse_random_half_in_place(fresh_w_quizzes)
- input[-nb:] = fresh_w_quizzes.to(self.device)
+ input[-nb:] = fresh_w_quizzes.to("cpu")
+
+ ######################################################################
def store_c_quizzes(self, new_c_quizzes, for_train=True):
- if for_train:
- self.train_c_quizzes.append(new_c_quizzes)
- else:
- self.test_c_quizzes.append(new_c_quizzes)
+ with self.LOCK_C_QUIZZES:
+ if for_train:
+ self.train_c_quizzes.append(new_c_quizzes.to("cpu"))
+ else:
+ self.test_c_quizzes.append(new_c_quizzes.to("cpu"))
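+ # The quizzes are kept on the CPU, presumably to spare device
+ # memory, and moved back to self.device where they are consumed.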
- def logproba_solution(self, models, c_quizzes):
- logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models))
+ ######################################################################
- for model in models:
- for input, l in zip(
- c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
- ):
- ar_mask = self.make_ar_mask(input)
- output = model(mygpt.BracketedSequence(input)).x
- ce = (
- F.cross_entropy(output.transpose(1, 2), input, reduction="none")
- * ar_mask
- )
- l[:, model.id] = -ce.sum(dim=-1)
+ def logproba_of_solutions(self, models, c_quizzes):
+ # dtype is made explicit: new_zeros would otherwise inherit the
+ # integer dtype of the token tensor c_quizzes, and the float
+ # log-probabilities assigned below would be silently truncated
+ logproba = c_quizzes.new_zeros(
+ c_quizzes.size(0), len(models), device=self.device, dtype=torch.float32
+ )
- return logproba
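+ # logproba[i, model.id] will hold the log-probability that the
+ # model assigns to the tokens of c_quizzes[i] selected by
+ # make_ar_mask (under teacher forcing).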
+ for model in models:
+ with torch.autograd.no_grad():
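+ # switch to eval mode and restore the previous mode once the
+ # log-probabilities have been computed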
+ t = model.training
+ model.eval()
+
+ for input, l in zip(
+ c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
+ ):
+ input = input.to(self.device)
+ ar_mask = self.make_ar_mask(input)
+ output = model(mygpt.BracketedSequence(input)).x
+ ce = (
+ F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+ * ar_mask
+ )
+ l[:, model.id] = -ce.sum(dim=-1)
+
+ model.train(t)
+
+ return logproba.to("cpu")
######################################################################
device=self.device,
)
- return c_quizzes
+ return c_quizzes.to("cpu")