nb_test_samples, acc_test_loss = 0, 0.0
nb_samples_accumulated = 0
- full_input, _ = quiz_machine.data_input(model, split="test")
- src = full_input.split(args.batch_size)
+ full_input, full_mask_loss = quiz_machine.data_input(model, split="test")
+ src = zip(
+ full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
+ )
- for input in tqdm.tqdm(src, dynamic_ncols=True, desc="test"):
+ for input, mask_loss in tqdm.tqdm(
+ src,
+ dynamic_ncols=True,
+ desc="test",
+ total=full_input.size(0) // args.batch_size,
+ ):
input = input.to(local_device)
+ mask_loss = mask_loss.to(local_device)
+ targets = input
+
output = model(mygpt.BracketedSequence(input)).x
- loss = F.cross_entropy(output.transpose(1, 2), input)
+ loss_per_token = F.cross_entropy(
+ output.transpose(1, 2), targets, reduction="none"
+ )
+ loss = (loss_per_token * mask_loss).mean()
acc_test_loss += loss.item() * input.size(0)
nb_test_samples += input.size(0)
hard_w_quizzes = []
- full_input, full_from_w = quiz_machine.data_input(model, split="train")
- src = zip(full_input.split(args.batch_size), full_from_w.split(args.batch_size))
+ full_input, full_mask_loss = quiz_machine.data_input(model, split="train")
+ src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
- for input, from_w in tqdm.tqdm(
+ for input, mask_loss in tqdm.tqdm(
src,
dynamic_ncols=True,
desc="training",
total=full_input.size(0) // args.batch_size,
):
input = input.to(local_device)
+ mask_loss = mask_loss.to(local_device)
if nb_train_samples % args.batch_size == 0:
model.optimizer.zero_grad()
loss_per_token = F.cross_entropy(
output.transpose(1, 2), targets, reduction="none"
)
- loss = loss_per_token.mean() + model.loss
+ loss = (loss_per_token * mask_loss).mean() + model.loss
acc_train_loss += loss.item() * input.size(0)
loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
- if from_w.any():
- hard_w_quizzes.append(
- (input[from_w].to("cpu"), loss_per_samples[from_w].to("cpu"))
- )
nb_train_samples += input.size(0)
self.answer_len = None
self.prompt_noise = prompt_noise
- # struct, mask_generate, mask_noise
+ # struct, mask_generate, mask_noise, mask_loss
self.understood_structures = [
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)),
- (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)),
- (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0)),
- (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0)),
- (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0)),
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)),
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+ (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)),
+ (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+ (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
+ ]
+
+ self.test_structures = [
+ self.understood_structures[0],
+ self.understood_structures[2],
+ self.understood_structures[4],
]
self.LOCK_C_QUIZZES = threading.Lock()
quizzes, from_w = quizzes[i], from_w[i]
self.randomize_configuations_inplace(
- quizzes, structs=[s for s, _, _ in self.understood_structures]
+ quizzes, structs=[s for s, _, _, _ in self.understood_structures]
)
+ quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
+
if self.prompt_noise > 0.0:
- for struct, _, mask_noise in self.understood_structures:
+ for struct, _, mask_noise, mask_loss in self.understood_structures:
i = self.problem.indices_select(quizzes=quizzes, struct=struct)
if i.any():
quizzes[i] = self.problem.inject_noise(
quizzes[i], self.prompt_noise, struct=struct, mask=mask_noise
)
+ quiz_mask_loss[i] = self.make_ar_mask(
+ quizzes=quizzes[i], struct=struct, mask=mask_loss
+ )
- return quizzes, from_w
+ return quizzes, quiz_mask_loss
######################################################################
# Thin wrapper around self.problem.make_ar_mask.
# `struct` must be the first field of one of the entries in
# self.understood_structures (checked by the assert below); `quizzes`,
# `struct` and `mask` are then forwarded unchanged and the problem's
# result is returned as-is.
def make_ar_mask(self, quizzes, struct, mask):
# The '-'/'+' pair below only changes the tuple unpacking from three
# fields to four, to follow the new layout of understood_structures.
- assert struct in [s for s, _, _ in self.understood_structures]
+ assert struct in [s for s, _, _, _ in self.understood_structures]
return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask)
######################################################################
nb = 0
# We consider all the configurations that we train for
- for struct, mask_generate, _ in self.understood_structures:
+ for struct, mask_generate, _, _ in self.test_structures:
i = self.problem.indices_select(quizzes=input, struct=struct)
nb += i.long().sum()
result[i], correct[i] = self.predict(