parser.add_argument("--prompt_noise", type=float, default=0.0)
-parser.add_argument("--nb_averaging_rounds", type=int, default=1)
+parser.add_argument("--nb_averaging_rounds", type=int, default=3)
parser.add_argument("--dirty_debug", action="store_true", default=False)
default_args = {
"model": "37M",
"batch_size": 25,
- "inference_batch_size": 100,
+ "inference_batch_size": 50,
"nb_train_samples": 100000,
"nb_test_samples": 10000,
}
batch_size=args.inference_batch_size,
result_dir=args.result_dir,
prompt_noise=args.prompt_noise,
- nb_averaging_rounds=args.nb_averaging_rounds,
logger=log_string,
device=main_device,
)
c_quizzes = c_quizzes[to_keep]
- # This is nb_quizzes x nb_models
- seq_logproba = quiz_machine.models_logprobas(
-     models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
- ) + quiz_machine.models_logprobas(
-     models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
- probas = seq_logproba.exp()
+ probas = 0
+
+ for a in range(args.nb_averaging_rounds):
+     # This is nb_quizzes x nb_models
+
+     seq_logproba = quiz_machine.models_logprobas(
+         models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+     ) + quiz_machine.models_logprobas(
+         models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+     )
+
+     probas += seq_logproba.exp()
+
+ probas /= args.nb_averaging_rounds
nb_succeed = (probas >= args.proba_understands).long().sum(dim=1)
nb_fail = (probas <= args.proba_not_understands).long().sum(dim=1)
vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]]
if vq.size(0) > 0:
- seq_logproba = quiz_machine.models_logprobas(
-     models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
- ) + quiz_machine.models_logprobas(
-     models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
+ probas = 0
+
+ for a in range(args.nb_averaging_rounds):
+     # This is nb_quizzes x nb_models
+
+     seq_logproba = quiz_machine.models_logprobas(
+         models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+     ) + quiz_machine.models_logprobas(
+         models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+     )
+
+     probas += seq_logproba.exp()
+
+ probas /= args.nb_averaging_rounds
comments = []
batch_size,
result_dir,
prompt_noise,
- nb_averaging_rounds,
logger,
device=torch.device("cpu"),
):
self.logger = logger
self.prompt_len = None
self.answer_len = None
-
- assert prompt_noise > 0 or nb_averaging_rounds == 1
-
self.prompt_noise = prompt_noise
- self.nb_averaging_rounds = nb_averaging_rounds
self.understood_structures = [
(("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)),
device=device,
)
- for a in range(self.nb_averaging_rounds):
-     if self.prompt_noise > 0.0 and noise_mask is not None:
-         c_quizzes = self.problem.inject_noise(
-             c_quizzes, self.prompt_noise, struct=struct, mask=noise_mask
-         )
+ if self.prompt_noise > 0.0 and noise_mask is not None:
+     c_quizzes = self.problem.inject_noise(
+         c_quizzes, self.prompt_noise, struct=struct, mask=noise_mask
+     )
-     for model in models_for_validation:
-         with torch.autograd.no_grad():
-             t = model.training
-             model.eval()
-
-             for input, l in zip(
-                 c_quizzes.split(self.batch_size),
-                 seq_logproba.split(self.batch_size),
-             ):
-                 input = input.to(device)
-                 ar_mask = self.make_ar_mask(input, struct=struct, mask=mask)
-                 output = model(mygpt.BracketedSequence(input)).x
-                 l[:, model.id] += (
-                     -F.cross_entropy(
-                         output.transpose(1, 2), input, reduction="none"
-                     )
-                     * ar_mask
-                 ).sum(dim=1)
-
-             model.train(t)
-
- return seq_logproba.div(self.nb_averaging_rounds).to("cpu")
+ for model in models_for_validation:
+     with torch.autograd.no_grad():
+         t = model.training
+         model.eval()
+
+         for input, l in zip(
+             c_quizzes.split(self.batch_size),
+             seq_logproba.split(self.batch_size),
+         ):
+             input = input.to(device)
+             ar_mask = self.make_ar_mask(input, struct=struct, mask=mask)
+             output = model(mygpt.BracketedSequence(input)).x
+             l[:, model.id] = (
+                 -F.cross_entropy(
+                     output.transpose(1, 2), input, reduction="none"
+                 )
+                 * ar_mask
+             ).sum(dim=1)
+
+         model.train(t)
+
+ return seq_logproba.to("cpu")
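
The quantity written into l[:, model.id] above is the log-probability the model assigns to the tokens selected by ar_mask: F.cross_entropy with reduction="none" returns the per-position negative log-likelihood, so negating it, masking, and summing over the sequence gives log p(masked part | rest) under teacher forcing. A minimal self-contained check of that identity with random logits (every name below is illustrative, not from the repo):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
nb_seq, seq_len, vocab = 2, 5, 7

logits = torch.randn(nb_seq, seq_len, vocab)        # stand-in for model(...).x
tokens = torch.randint(vocab, (nb_seq, seq_len))    # stand-in for input
ar_mask = torch.tensor([[0, 0, 0, 1, 1]] * nb_seq)  # positions being scored

# Per-position negative log-likelihood, masked and summed, as in the diff.
nll = F.cross_entropy(logits.transpose(1, 2), tokens, reduction="none")
seq_logproba = (-nll * ar_mask).sum(dim=1)

# Same value computed directly from the softmax, confirming the identity.
logp = logits.log_softmax(dim=2).gather(2, tokens[:, :, None]).squeeze(2)
assert torch.allclose(seq_logproba, (logp * ar_mask).sum(dim=1), atol=1e-5)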
######################################################################
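
With the averaging moved out of models_logprobas, the validation logic in main.py reduces to: score the candidate quizzes nb_averaging_rounds times (each call can re-inject fresh prompt noise), average the resulting per-model probabilities, then count the models above proba_understands and below proba_not_understands. A minimal sketch of that flow under those assumptions; average_round_probas, count_verdicts, the threshold values and the fake scorer are illustrative, not part of the repo:

import torch


def average_round_probas(logprobas_fn, quizzes, nb_rounds):
    # Average exp(log-proba) over independent rounds: each call may see
    # different prompt noise, so probabilities (not log-probas) are averaged.
    probas = 0
    for _ in range(nb_rounds):
        probas = probas + logprobas_fn(quizzes).exp()  # nb_quizzes x nb_models
    return probas / nb_rounds


def count_verdicts(probas, proba_understands=0.9, proba_not_understands=0.1):
    # Per quiz: how many models clear the upper threshold / fall below the lower one.
    nb_succeed = (probas >= proba_understands).long().sum(dim=1)
    nb_fail = (probas <= proba_not_understands).long().sum(dim=1)
    return nb_succeed, nb_fail


# Stand-in for quiz_machine.models_logprobas: random log-probas for 8 quizzes
# and 4 models, re-sampled at every round to mimic the effect of prompt noise.
fake_logprobas = lambda q: torch.rand(q.size(0), 4).log()
quizzes = torch.zeros(8, 1)
probas = average_round_probas(fake_logprobas, quizzes, nb_rounds=3)
print(count_verdicts(probas))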