import threading
+import torch.multiprocessing as mp
+
# world quizzes vs. culture quizzes
######################################################################
)
for model in models:
- for input, l in zip(
- c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
- ):
- input = input.to(self.device)
- ar_mask = self.make_ar_mask(input)
- output = model(mygpt.BracketedSequence(input)).x
- ce = (
- F.cross_entropy(output.transpose(1, 2), input, reduction="none")
- * ar_mask
- )
- l[:, model.id] = -ce.sum(dim=-1)
+ with torch.autograd.no_grad():
+ t = model.training
+ model.eval()
+
+ for input, l in zip(
+ c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
+ ):
+ input = input.to(self.device)
+ ar_mask = self.make_ar_mask(input)
+ output = model(mygpt.BracketedSequence(input)).x
+ ce = (
+ F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+ * ar_mask
+ )
+ l[:, model.id] = -ce.sum(dim=-1)
+
+ model.train(t)
return logproba.to("cpu")