model.train()
nb_train_samples, acc_train_loss = 0, 0.0
- full_input, full_mask_loss = quiz_machine.data_input(args.nb_train_samples)
+ data_structures = [
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
+ ]
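+ # Reading inferred from the unpacking in data_input below, not stated in the
+ # patch: each entry is (struct, quad, quad_noise, quad_loss). The third quad
+ # flags the segments that receive prompt noise when noise is enabled (here "B"),
+ # and the fourth flags the segments scored by the loss (here "f_B"); the second
+ # quad is ignored in this code path (unpacked as "_").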
+
+ full_input, full_mask_loss = quiz_machine.data_input(
+ args.nb_train_samples, data_structures=data_structures
+ )
src = zip(
    full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
)

for input, mask_loss in src:
    targets = input
    # Zero out the tokens at the positions selected by the loss mask so they
    # are hidden from the model's input; the untouched quiz stays the target.
    input = (mask_loss == 0).long() * input

    output = model(mygpt.BracketedSequence(input)).x
    # F.cross_entropy expects the class dimension in position 1, hence the transpose.
    loss = F.cross_entropy(output.transpose(1, 2), targets)
    acc_train_loss += loss.item() * input.size(0)
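
# Toy illustration (editor's sketch, not part of the patch) of the masking step
# above: with a loss quad of (0, 0, 0, 1), mask_loss is 1 on the f_B tokens only,
# so those tokens are blanked in the input while the untouched quiz remains the
# target. One token per segment is assumed purely for illustration.
import torch

toy_input = torch.tensor([[5, 6, 7, 8]])      # one quiz: A, f_A, B, f_B
toy_mask_loss = torch.tensor([[0, 0, 0, 1]])  # loss computed on f_B only
toy_targets = toy_input
toy_input = (toy_mask_loss == 0).long() * toy_input  # tensor([[5, 6, 7, 0]])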
######################################################################
- def data_input(self, nb_samples, c_quiz_bags=[], c_quiz_multiplier=1):
+ def data_input(
+ self, nb_samples, c_quiz_bags=[], c_quiz_multiplier=1, data_structures=None
+ ):
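+ # Default to the regular training structures so existing callers that do not
+ # pass data_structures keep their current behavior.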
+ if data_structures is None:
+ data_structures = self.train_structures
+
if len(c_quiz_bags) > 0:
c_quizzes = torch.cat(c_quiz_bags, dim=0)
quizzes = quizzes[i]
self.randomize_configuations_inplace(
- quizzes, structs=[s for s, _, _, _ in self.train_structures]
+ quizzes, structs=[s for s, _, _, _ in data_structures]
)
quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
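+ # Start with every token contributing to the loss; the per-structure masks
+ # built below overwrite the rows they select.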
- if self.prompt_noise > 0.0:
- for struct, _, quad_noise, quad_loss in self.train_structures:
- i = self.problem.indices_select(quizzes=quizzes, struct=struct)
- if i.any():
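+ # Build the loss mask for every structure; only the noise injection remains
+ # conditional on prompt_noise being strictly positive.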
+ for struct, _, quad_noise, quad_loss in data_structures:
+ i = self.problem.indices_select(quizzes=quizzes, struct=struct)
+ if i.any():
+ if self.prompt_noise > 0.0:
quizzes[i] = self.problem.inject_noise(
quizzes[i], self.prompt_noise, struct=struct, quad=quad_noise
)
- quiz_mask_loss[i] = self.make_quiz_mask(
- quizzes=quizzes[i], struct=struct, quad=quad_loss
- )
+ quiz_mask_loss[i] = self.make_quiz_mask(
+ quizzes=quizzes[i], struct=struct, quad=quad_loss
+ )
+
+ print("quad_loss", quad_loss)
+ print("quiz_mask_loss", quiz_mask_loss)
return quizzes, quiz_mask_loss
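
# A minimal sketch (hypothetical, not the repo's make_quiz_mask) of how a quad
# can be expanded into a per-token loss mask over four equal-length segments.
# quad_to_mask, seg_len, and the equal-segment layout are assumptions made only
# for illustration.
import torch

def quad_to_mask(quad, seg_len, batch_size):
    # Repeat each of the four quad flags over its segment, then over the batch.
    flags = torch.tensor(quad).repeat_interleave(seg_len)
    return flags.unsqueeze(0).expand(batch_size, -1).clone()

m = quad_to_mask((0, 0, 0, 1), seg_len=3, batch_size=2)
# m[:, :9] is all 0 and m[:, 9:] is all 1: only the last segment is scored.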