# Imports used by this section (normally grouped at the top of the file)
import math

import torch
import torch.nn.functional as F

import mygpt

######################################################################
# Compute the entropy of the training tokens

token_count = 0
for input in quiz_machine.batches(split="train", desc="train-entropy"):
    token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum(
        (0, 1)
    )
token_probas = token_count / token_count.sum()
# xlogy(p, p) is 0 when p == 0, so tokens that never occur add no entropy
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
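
######################################################################
# Quick sanity check (illustrative, not part of the original pipeline):
# for a uniform distribution over V tokens the entropy is log(V), so
# the perplexity computed as above is exactly V, the worst case the
# training set can reach.

_V = quiz_machine.vocabulary_size()
_uniform = torch.full((_V,), 1 / _V)
assert math.isclose(
    math.exp(-torch.xlogy(_uniform, _uniform).sum().item()), _V, rel_tol=1e-4
)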

######################################################################
# A bit of paranoia never hurts

if args.max_percents_of_test_in_train >= 0:

    def subsets_as_tuples(batches, cs):
        # Stream the quizzes as hashable tuples of token values, grouped
        # in sets of at most cs elements
        s = set()
        for batch in batches:
            for x in batch:
                s.add(tuple([v.item() for v in x]))
                if len(s) == cs:
                    yield s
                    s = set()
        yield s

    # Count how many test quizzes appear verbatim in the train set,
    # comparing the two streams chunk by chunk to bound memory usage
    nb_test, nb_in_train = 0, 0
    for test_subset in subsets_as_tuples(
        quiz_machine.batches(split="test", desc="test-check"), 25000
    ):
        in_train = set()
        for train_subset in subsets_as_tuples(
            quiz_machine.batches(split="train", desc="train-check"), 25000
        ):
            in_train.update(test_subset.intersection(train_subset))
        nb_in_train += len(in_train)
        nb_test += len(test_subset)

    log_string(
        f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
    )

    assert (
        nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
    ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"

######################################################################


def one_epoch(model, quiz_machine):
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in quiz_machine.batches(split="train"):
        input = input.to(device)

        # Zero the gradients at the start of every batch of
        # args.batch_size samples, so that gradients accumulate when the
        # physical batches delivered by quiz_machine are smaller
        if nb_train_samples % args.batch_size == 0:
            optimizer.zero_grad()

        # mygpt shifts the input sequence internally, so the same tensor
        # serves as model input and as cross-entropy target; the
        # transpose puts the logits in the (N, C, T) shape that
        # F.cross_entropy expects
        output = model(mygpt.BracketedSequence(input)).x
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)

        nb_train_samples += input.size(0)

        loss.backward()

        if nb_train_samples % args.batch_size == 0:
            optimizer.step()

    # Clamp the average loss at 100 to avoid overflowing exp
    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))

    # n_epoch is a global maintained by the surrounding epoch loop
    log_string(f"train_perplexity {n_epoch} {train_perplexity}")

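######################################################################
# A sketch of the surrounding driver (assumed, for illustration; the
# actual loop in this file may differ). one_epoch reads the global
# n_epoch when logging, so the epoch counter must live at module scope.
#
# for n_epoch in range(args.nb_epochs):
#     one_epoch(model, quiz_machine)
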
######################################################################
