parser.add_argument("--prompt_noise", type=float, default=0.0)
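+# Number of rounds of prompt-noise injection over which the models'
+# log-probas are averaged (requires --prompt_noise > 0 when greater than 1)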
+parser.add_argument("--nb_averaging_rounds", type=int, default=1)
+
parser.add_argument("--dirty_debug", action="store_true", default=False)
parser.add_argument("--test_generator", action="store_true", default=False)
    batch_size=args.inference_batch_size,
    result_dir=args.result_dir,
    prompt_noise=args.prompt_noise,
+    nb_averaging_rounds=args.nb_averaging_rounds,
    logger=log_string,
    device=main_device,
)
        batch_size,
        result_dir,
        prompt_noise,
+        nb_averaging_rounds,
        logger,
        device=torch.device("cpu"),
    ):
        self.logger = logger
        self.prompt_len = None
        self.answer_len = None
+
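+        # Averaging over several rounds makes sense only if each round sees
+        # a different noisy version of the prompts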
+        assert prompt_noise > 0 or nb_averaging_rounds == 1
+
        self.prompt_noise = prompt_noise
+        self.nb_averaging_rounds = nb_averaging_rounds
        self.understood_structures = [
            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)),
        c_quizzes = self.problem.reconfigure(c_quizzes, struct)
-        if self.prompt_noise > 0.0 and noise_mask is not None:
-            c_quizzes = self.problem.inject_noise(
-                c_quizzes, self.prompt_noise, struct=struct, mask=noise_mask
-            )
-
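+        # One row per quiz, one column per model id, accumulating the masked
+        # log-probas of every averaging round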
        seq_logproba = torch.zeros(
            c_quizzes.size(0),
            max([m.id for m in models_for_validation]) + 1,
            device=device,
        )
-        for model in models_for_validation:
-            with torch.autograd.no_grad():
-                t = model.training
-                model.eval()
-
-                for input, l in zip(
-                    c_quizzes.split(self.batch_size),
-                    seq_logproba.split(self.batch_size),
-                ):
-                    input = input.to(device)
-                    ar_mask = self.make_ar_mask(input, struct=struct, mask=mask)
-                    output = model(mygpt.BracketedSequence(input)).x
-                    l[:, model.id] = (
-                        -F.cross_entropy(
-                            output.transpose(1, 2), input, reduction="none"
-                        )
-                        * ar_mask
-                    ).sum(dim=1)
-
-                model.train(t)
-
-        return seq_logproba.to("cpu")
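+        # Evaluate every model on nb_averaging_rounds noisy versions of the
+        # prompts and average the resulting masked log-probas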
+        for a in range(self.nb_averaging_rounds):
+            if self.prompt_noise > 0.0 and noise_mask is not None:
+                # Draw fresh noise from the clean quizzes at each round so
+                # that noise does not pile up across rounds (this assumes
+                # inject_noise returns a new tensor)
+                noisy_c_quizzes = self.problem.inject_noise(
+                    c_quizzes, self.prompt_noise, struct=struct, mask=noise_mask
+                )
+            else:
+                noisy_c_quizzes = c_quizzes
+
+            for model in models_for_validation:
+                with torch.autograd.no_grad():
+                    t = model.training
+                    model.eval()
+
+                    for input, l in zip(
+                        noisy_c_quizzes.split(self.batch_size),
+                        seq_logproba.split(self.batch_size),
+                    ):
+                        input = input.to(device)
+                        ar_mask = self.make_ar_mask(input, struct=struct, mask=mask)
+                        output = model(mygpt.BracketedSequence(input)).x
+                        l[:, model.id] += (
+                            -F.cross_entropy(
+                                output.transpose(1, 2), input, reduction="none"
+                            )
+                            * ar_mask
+                        ).sum(dim=1)
+
+                    model.train(t)
+
+        return seq_logproba.div(self.nb_averaging_rounds).to("cpu")
######################################################################