######################################################################
parser = argparse.ArgumentParser(
- description="An implementation of GPT with cache.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
log_string(f"vocabulary_size {vocabulary_size}")
######################################################################
-
-# Compute the entropy of the training tokens
-
-token_count = 0
-for input in quiz_machine.batches(split="train", desc="train-entropy"):
- token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum(
- (0, 1)
- )
-token_probas = token_count / token_count.sum()
-entropy = -torch.xlogy(token_probas, token_probas).sum()
-train_set_perplexity = math.exp(entropy)
-
-######################################################################
-# A bit of paranoia never hurts
-
-if args.max_percents_of_test_in_train >= 0:
-
- def subsets_as_tuples(batches, cs):
- s = set()
- for batch in batches:
- for x in batch:
- s.add(tuple([v.item() for v in x]))
- if len(s) == cs:
- yield s
- s = set()
- yield s
-
- nb_test, nb_in_train = 0, 0
- for test_subset in subsets_as_tuples(
- quiz_machine.batches(split="test", desc="test-check"), 25000
- ):
- in_train = set()
- for train_subset in subsets_as_tuples(
- quiz_machine.batches(split="train", desc="train-check"), 25000
- ):
- in_train.update(test_subset.intersection(train_subset))
- nb_in_train += len(in_train)
- nb_test += len(test_subset)
-
- log_string(
- f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
- )
-
- assert (
- nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
- ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
-
##############################
nb_train_samples, acc_train_loss = 0, 0.0
- for input in quiz_machine.batches(split="train"):
+ for input in quiz_machine.batches(model, split="train"):
input = input.to(device)
if nb_train_samples % args.batch_size == 0:
nb_test_samples, acc_test_loss = 0, 0.0
nb_samples_accumulated = 0
- for input in quiz_machine.batches(split="test"):
+ for input in quiz_machine.batches(model, split="test"):
input = input.to(device)
bs = model(mygpt.BracketedSequence(input))
model.main_test_accuracy = 0.0
model.id = k
+ model.train_w_quizzes = quiz_machine.generate_token_sequences(
+ args.nb_train_samples
+ ).to(device)
+ quiz_machine.reverse_random_half_in_place(model.train_w_quizzes)
+ model.test_w_quizzes = quiz_machine.generate_token_sequences(
+ args.nb_test_samples
+ ).to(device)
+ quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
+
models.append(model)
######################################################################
+# Compute the entropy of the training tokens. Only the train batches of
+# models[0] are used; each model now carries its own w_quizzes.
+
+token_count = 0
+for input in quiz_machine.batches(models[0], split="train", desc="train-entropy"):
+ token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum(
+ (0, 1)
+ )
+token_probas = token_count / token_count.sum()
+entropy = -torch.xlogy(token_probas, token_probas).sum()
+train_set_perplexity = math.exp(entropy)
+
+######################################################################
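+# A bit of paranoia never hurts: count how many test samples also appear
+# in the train set (batches of models[0] only) and fail if the percentage
+# exceeds args.max_percents_of_test_in_train. A negative value of that
+# option disables the check.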
+
+if args.max_percents_of_test_in_train >= 0:
+
+ def subsets_as_tuples(batches, cs):  # generator: regroups the samples of `batches` into sets of tuples, cs samples per set
+ s = set()
+ for batch in batches:
+ for x in batch:
+ s.add(tuple([v.item() for v in x]))  # tuples are hashable; v.item() presumably extracts a Python scalar from each element
+ if len(s) == cs:  # chunk holds cs distinct samples: hand it out and start a fresh one
+ yield s
+ s = set()
+ yield s  # presumably after the loops (indentation lost in this view): the remainder, which may be an empty set
+
+ nb_test, nb_in_train = 0, 0
+ for test_subset in subsets_as_tuples(
+ quiz_machine.batches(models[0], split="test", desc="test-check"), 25000
+ ):
+ in_train = set()
+ for train_subset in subsets_as_tuples(
+ quiz_machine.batches(models[0], split="train", desc="train-check"), 25000
+ ):
+ in_train.update(test_subset.intersection(train_subset))
+ nb_in_train += len(in_train)
+ nb_test += len(test_subset)
+
+ log_string(
+ f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
+ )
+
+ assert (
+ nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
+ ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
+
+######################################################################
+
nb_new_c_quizzes_for_train = args.nb_train_samples // 50
nb_new_c_quizzes_for_test = args.nb_test_samples // 50
log_string(
f"cache_w_quizzes contains {quiz_machine.problem.nb_cached_quizzes()} quizzes"
)
- quiz_machine.renew_w_quizzes(args.nb_train_samples // args.nb_gpts)
+ quiz_machine.renew_w_quizzes(model, args.nb_train_samples // args.nb_gpts)
##################################################
# If all the models are good enough, generate new quizzes and
self.prompt_len = None
self.answer_len = None
- self.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
- self.reverse_random_half_in_place(self.train_w_quizzes)
- self.train_w_quizzes = self.train_w_quizzes.to(device)
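+ # The w_quizzes are no longer stored on this object: they are read
+ # from model.train_w_quizzes / model.test_w_quizzes (see batches()).
+ # The former initialization is left commented out below.
+ # self.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
+ # self.reverse_random_half_in_place(self.train_w_quizzes)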
- self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)
- self.reverse_random_half_in_place(self.test_w_quizzes)
- self.test_w_quizzes = self.test_w_quizzes.to(device)
+ # self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)
+ # self.reverse_random_half_in_place(self.test_w_quizzes)
self.train_c_quizzes = []
self.test_c_quizzes = []
- if result_dir is not None:
- self.save_quizzes(
- result_dir,
- "culture_w_quizzes",
- self.train_w_quizzes[:72],
- )
+ # Disabled: this call read self.train_w_quizzes, which is no longer
+ # created here.
+ # if result_dir is not None:
+ # self.save_quizzes(
+ # result_dir,
+ # "culture_w_quizzes",
+ # self.train_w_quizzes[:72],
+ # )
def save_quizzes(
self,
predicted_answers,
)
- def batches(self, split="train", desc=None):
+ def batches(self, model, split="train", desc=None):
assert split in {"train", "test"}
if split == "train":
- w_quizzes = self.train_w_quizzes
+ w_quizzes = model.train_w_quizzes
c_quizzes = self.train_c_quizzes
else:
- w_quizzes = self.test_w_quizzes
+ w_quizzes = model.test_w_quizzes
c_quizzes = self.test_c_quizzes
if len(c_quizzes) > 0:
return result, correct
- compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train")
+ compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
test_result, test_correct = compute_accuracy(
- self.test_w_quizzes[:nmax], log_prefix="test"
+ model.test_w_quizzes[:nmax], log_prefix="test"
)
main_test_accuracy = test_correct.sum() / test_correct.size(0)
return main_test_accuracy
- def renew_w_quizzes(self, nb, for_train=True):
- input = self.train_w_quizzes if for_train else self.test_w_quizzes
+ def renew_w_quizzes(self, model, nb, for_train=True):
+ input = model.train_w_quizzes if for_train else model.test_w_quizzes
nb = min(nb, input.size(0))
input[:-nb] = input[nb:].clone()
fresh_w_quizzes = self.generate_token_sequences(nb)
def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
c_quizzes = torch.empty(
- nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
+ nb,
+ self.prompt_len + self.answer_len + 2,
+ device=self.device,
+ dtype=torch.int64,
)
seq_logproba = torch.zeros(nb, device=self.device)