train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
- log_string(f"train_perplexity {n_epoch} {train_perplexity}")
+ log_string(f"train_perplexity {n_epoch} model.id {model.id} {train_perplexity}")
run_tests(model, quiz_machine, deterministic_synthesis=False)
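# a c-quiz is valid when, after sorting the per-model log-probas, the weakest
# model is unsure of the solution (p < 0.5) while the second-weakest is
# confident (p > 0.99)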
def standard_validity(logproba):
l = logproba.sort(dim=-1).values
return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.99))
- # warnings.warn("TEST!!!", RuntimeWarning)
- # print(l.exp())
- # return (l[:, 0] < math.log(0.99))
def valid_c_quizzes(recorded, criteria):
model.id = k
model.TRAINING_LOCK = threading.Lock()
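# the w_quizzes are now kept on the CPU; batches are moved to the compute
# device on demand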
- model.train_w_quizzes = quiz_machine.generate_token_sequences(
- args.nb_train_samples
- ).to(device)
+ model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples)
quiz_machine.reverse_random_half_in_place(model.train_w_quizzes)
- model.test_w_quizzes = quiz_machine.generate_token_sequences(
- args.nb_test_samples
- ).to(device)
+ model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples)
quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
models.append(model)
nb_new_c_quizzes_for_train = 100
nb_new_c_quizzes_for_test = 10
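+ # relaxed criterion: a c-quiz is valid as soon as the weakest model is
+ # unsure of the solution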
+ def standard_validity(logproba):
+ l = logproba.sort(dim=-1).values
+ return l[:, 0] < math.log(0.5)
+
+
######################################################################
for n_epoch in range(args.nb_epochs):
def produce_results(
    self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
):
def compute_accuracy(input, log_prefix=None):
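+ # the w_quizzes live on the CPU, so move the batch to the compute device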
+ input = input.to(self.device)
ar_mask = self.make_ar_mask(input)
result = input.clone() * (1 - ar_mask)
seq_logproba = torch.empty(input.size(0), device=self.device)
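# drop the nb oldest w_quizzes, shift the rest forward, and refill the tail
# with freshly generated ones, stored on the CPU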
input[:-nb] = input[nb:].clone()
fresh_w_quizzes = self.generate_token_sequences(nb)
self.reverse_random_half_in_place(fresh_w_quizzes)
- input[-nb:] = fresh_w_quizzes.to(self.device)
+ input[-nb:] = fresh_w_quizzes.to("cpu")
######################################################################
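# new c_quizzes are moved to the CPU before being stored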
def store_c_quizzes(self, new_c_quizzes, for_train=True):
with self.LOCK_C_QUIZZES:
if for_train:
- self.train_c_quizzes.append(new_c_quizzes)
+ self.train_c_quizzes.append(new_c_quizzes.to("cpu"))
else:
- self.test_c_quizzes.append(new_c_quizzes)
+ self.test_c_quizzes.append(new_c_quizzes.to("cpu"))
######################################################################
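# for each model, compute the log-proba it assigns to the solution of every
# c_quiz; accumulation happens on the compute device, the result is returned
# on the CPU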
def logproba_of_solutions(self, models, c_quizzes):
- logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models))
+ logproba = c_quizzes.new_zeros(
+ c_quizzes.size(0), len(models), device=self.device
+ )
for model in models:
for input, l in zip(
c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
):
+ input = input.to(self.device)
ar_mask = self.make_ar_mask(input)
output = model(mygpt.BracketedSequence(input)).x
ce = (
    # assumed reconstruction of the elided body: per-token cross-entropy
    # (F = torch.nn.functional), masked to the positions the model generates
    F.cross_entropy(output.transpose(1, 2), input, reduction="none")
    * ar_mask
)
l[:, model.id] = -ce.sum(dim=-1)
- return logproba
+ return logproba.to("cpu")
###############################################################
device=self.device,
)
- return c_quizzes
+ return c_quizzes.to("cpu")
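######################################################################
# A minimal, self-contained sketch of the storage pattern the changes above
# converge on (tensor names are hypothetical, not from this codebase): bulk
# quiz tensors stay on the CPU, one batch at a time is moved to the compute
# device, and results are brought back to CPU storage.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

pool = torch.randint(0, 100, (1000, 64))  # the full pool stays on the CPU

means = []
for batch in pool.split(25):
    batch = batch.to(device)  # move a single batch to the device
    out = batch.float().mean(dim=-1)  # stand-in for a forward pass
    means.append(out.to("cpu"))  # results go back to CPU storage

means = torch.cat(means)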