nb_test_samples, acc_test_loss = 0, 0.0
nb_samples_accumulated = 0
- for input in quiz_machine.batches(model, split="test"):
+ full_input, _ = quiz_machine.data_input(model, split="test")
+ src = full_input.split(args.batch_size)
+
+ for input in tqdm.tqdm(src, dynamic_ncols=True, desc="test"):
input = input.to(local_device)
bs = model(mygpt.BracketedSequence(input))
nb_train_samples, acc_train_loss = 0, 0.0
- for input in quiz_machine.batches(model, split="train"):
+ hard_w_quizzes = []
+
+ full_input, full_from_w = quiz_machine.data_input(model, split="train")
+ src = zip(full_input.split(args.batch_size), full_from_w.split(args.batch_size))
+
+ for input, from_w in tqdm.tqdm(src, dynamic_ncols=True, desc="training"):
input = input.to(local_device)
if nb_train_samples % args.batch_size == 0:
optimizer.zero_grad()
output = model(mygpt.BracketedSequence(input)).x
- loss = F.cross_entropy(output.transpose(1, 2), input)
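+ # Keep the per-token cross-entropy un-reduced so that a per-sample loss can be extracted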
+ loss_per_token = F.cross_entropy(
+ output.transpose(1, 2), input, reduction="none"
+ )
+ loss = loss_per_token.mean()
acc_train_loss += loss.item() * input.size(0)
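+ # Average the token losses of each sample and record the losses of this batch's world quizzes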
+ loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
+ hard_w_quizzes.append(
+ (input[from_w].to("cpu"), loss_per_samples[from_w].to("cpu"))
+ )
+
nb_train_samples += input.size(0)
loss.backward()
run_tests(model, quiz_machine, deterministic_synthesis=False)
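+ # Threshold at the median per-sample loss over all the world quizzes seen this epoch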
+ threshold = torch.cat([x[1] for x in hard_w_quizzes], dim=0).sort().values
+ threshold = threshold[threshold.size(0) // 2]
+
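+ # Keep only the world quizzes whose loss is at or above the median (the harder half)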
+ model.hard_w_quizzes = torch.cat(
+ [x[0][x[1] >= threshold] for x in hard_w_quizzes], dim=0
+ )
+
model.to(main_device)
e = "???"
log_string(
- f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {(total_nb_validated * 3600)/duration:0.1f}/h)"
+ f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)"
)
validated_quizzes = torch.cat(recorded, dim=0)
######################################################################
-# Compute the entropy of the training tokens
-
-token_count = 0
-for input in quiz_machine.batches(models[0], split="train", desc="train-entropy"):
- token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum(
- (0, 1)
- )
-token_probas = token_count / token_count.sum()
-entropy = -torch.xlogy(token_probas, token_probas).sum()
-train_set_perplexity = math.exp(entropy)
-
-######################################################################
-# A bit of paranoia never hurts
-
-if args.max_percents_of_test_in_train >= 0:
-
- def subsets_as_tuples(batches, cs):
- s = set()
- for batch in batches:
- for x in batch:
- s.add(tuple([v.item() for v in x]))
- if len(s) == cs:
- yield s
- s = set()
- yield s
-
- nb_test, nb_in_train = 0, 0
- for test_subset in subsets_as_tuples(
- quiz_machine.batches(models[0], split="test", desc="test-check"), 25000
- ):
- in_train = set()
- for train_subset in subsets_as_tuples(
- quiz_machine.batches(models[0], split="train", desc="train-check"), 25000
- ):
- in_train.update(test_subset.intersection(train_subset))
- nb_in_train += len(in_train)
- nb_test += len(test_subset)
-
- log_string(
- f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
- )
-
- assert (
- nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
- ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
-
-######################################################################
-
if args.nb_new_c_quizzes_for_train is None:
args.nb_new_c_quizzes_for_train = args.nb_train_samples // 100
for model in weakest_models:
quiz_machine.renew_w_quizzes(
model=model,
- nb=args.nb_train_samples,
for_train=True,
forward_only=args.forward_only,
)
######################################################################
- def batches(self, model, split="train", desc=None):
+ def data_input(self, model, split="train"):
assert split in {"train", "test"}
with self.LOCK_C_QUIZZES:
]
w_quizzes = w_quizzes[i]
- self.nb_batch_w_quizzes = w_quizzes.size(0)
- self.nb_batch_c_quizzes = c_quizzes.size(0)
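+ # Concatenate world and c-quizzes, and build a mask flagging the rows that are world quizzes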
+ quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
+ from_w = torch.arange(
+ quizzes.size(0), device=quizzes.device
+ ) < w_quizzes.size(0)
+ i = torch.randperm(quizzes.size(0), device=quizzes.device)
- input = torch.cat([w_quizzes, c_quizzes], dim=0)
- else:
- input = w_quizzes
- self.nb_batch_w_quizzes = w_quizzes.size(0)
- self.nb_batch_c_quizzes = 0
-
- # Shuffle
- input = input[torch.randperm(input.size(0))]
+ return quizzes[i], from_w[i]
- if desc is None:
- desc = f"epoch-{split}"
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=desc
- ):
- yield batch
+ else:
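+ # Fall back to world quizzes only, with every row flagged as coming from the world generator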
+ return w_quizzes, torch.full(
+ (w_quizzes.size(0),), True, device=w_quizzes.device
+ )
######################################################################
######################################################################
- def renew_w_quizzes(self, model, nb, for_train=True, forward_only=False):
+ def renew_w_quizzes(self, model, for_train=True, forward_only=False):
input = model.train_w_quizzes if for_train else model.test_w_quizzes
- nb = min(nb, input.size(0))
- input[:-nb] = input[nb:].clone()
- fresh_w_quizzes = self.generate_token_sequences(nb)
- if not forward_only:
- self.reverse_random_half_in_place(fresh_w_quizzes)
- input[-nb:] = fresh_w_quizzes.to("cpu")
+
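+ # Re-use the hard world quizzes recorded during the last epoch, topping up with fresh sequences if there are not enough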
+ if for_train and hasattr(model, "hard_w_quizzes"):
+ self.logger(
+ f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
+ )
+ if model.hard_w_quizzes.size(0) >= input.size(0):
+ input[...] = model.hard_w_quizzes[
+ torch.randperm(model.hard_w_quizzes.size(0))[: input.size(0)]
+ ]
+ else:
+ input[...] = torch.cat(
+ [
+ model.hard_w_quizzes,
+ self.generate_token_sequences(
+ input.size(0) - model.hard_w_quizzes.size(0)
+ ),
+ ],
+ dim=0,
+ )
+ else:
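+ # No hard quizzes recorded for this model: regenerate the whole set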
+ input[...] = self.generate_token_sequences(input.size(0))
######################################################################