def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
nb_to_validate = nb_for_train + nb_for_test
- nb_to_generate_per_iteration = nb_to_validate
+ nb_to_generate_per_iteration = max(args.physical_batch_size, nb_to_validate)
nb_validated = 0
recorded_validated = []
# This is nb_quizzes x nb_models
number_correct_responses = 0
+ remains = [c_quizzes.size(0)]
+
for r in range(args.nb_rounds):
number_correct_responses += quiz_machine.models_successes(models, c_quizzes)
c_quizzes = c_quizzes[to_keep]
number_correct_responses = number_correct_responses[to_keep]
- log_string(f"round {r} remains {c_quizzes.size(0)}")
+ remains.append(c_quizzes.size(0))
if c_quizzes.size(0) == 0:
break
else:
e = "???"
+ v = " ".join([str(x) for x in remains])
+ log_string(f"filter c_quizzes {v}")
+
log_string(
f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)"
)
if vq.size(0) > 0:
prefix = f"culture_c_quiz_{n_epoch:04d}"
+
+ number_correct_responses = 0
+ for r in range(args.nb_rounds):
+ number_correct_responses += quiz_machine.models_successes(models, vq)
+
+ with open(os.path.join(args.result_dir, prefix + "_responses.dat"), "w") as f:
+ for n, r in enumerate(number_correct_responses):
+ v = " ".join([str(n.item()) for n in r])
+ f.write(f"{n}: {v}\n")
+
quiz_machine.save_quiz_illustrations(
args.result_dir, prefix, vq, show_part_to_predict=False
)
return c_quizzes.to("cpu")
######################################################################
+
+ def generate_c_quizzes_mixing(
+ self,
+ nb,
+ model_for_generation,
+ p2a_only=False,
+ temperature_hot=1.0,
+ temperature_cold=1.0,
+ ):
+ c_quizzes = torch.empty(
+ nb,
+ self.prompt_len + self.answer_len,
+ device=self.device,
+ dtype=torch.int64,
+ )
+
+ c_quizzes_1 = torch.empty(
+ nb,
+ self.prompt_len + self.answer_len,
+ device=self.device,
+ dtype=torch.int64,
+ )
+
+ c_quizzes_2 = torch.empty(
+ nb,
+ self.prompt_len + self.answer_len,
+ device=self.device,
+ dtype=torch.int64,
+ )
+
+ seq_logproba = torch.zeros(nb, device=self.device)
+
+ lt_noisy = lambda s, logits: logits / temperature_hot
+ lt_clean = lambda s, logits: logits / temperature_cold
+
+ ######################################################################
+
+ c_quizzes_1[...] = self.problem.token_backward
+ ar_mask = self.problem.make_ar_mask(c_quizzes_1, shape="fwd_012_bck_0")
+
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes_1,
+ ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
+ logit_transformer=lt_noisy,
+ deterministic_synthesis=False,
+ device=self.device,
+ )
+
+ self.save_quiz_illustrations("/tmp", f"c_quizzes_1", c_quizzes_1)
+
+ c_quizzes_2[...] = self.problem.token_backward
+
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes_2,
+ ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
+ logit_transformer=lt_noisy,
+ deterministic_synthesis=False,
+ device=self.device,
+ )
+
+ self.save_quiz_illustrations("/tmp", f"c_quizzes_2", c_quizzes_2)
+
+ h = len(model_for_generation.trunk) // 2
+
+ with torch.autograd.no_grad():
+ t = model_for_generation.training
+ model_for_generation.eval()
+
+ bs1 = model_for_generation.partial_forward(
+ mygpt.BracketedSequence(c_quizzes_1), end_layer=h
+ )
+ bs2 = model_for_generation.partial_forward(
+ mygpt.BracketedSequence(c_quizzes_2), end_layer=h
+ )
+
+ alpha = 0.5
+
+ output = model_for_generation.partial_forward(
+ mygpt.BracketedSequence(alpha * bs1.x + (1 - alpha) * bs2.x),
+ start_layer=h,
+ ).x
+
+ dist = torch.distributions.categorical.Categorical(logits=output)
+ c_quizzes[...] = dist.sample()
+
+ c_quizzes[...] = (
+ ar_mask * c_quizzes + (1 - ar_mask) * self.problem.token_backward
+ )
+
+ model_for_generation.train(t)
+
+ self.save_quiz_illustrations("/tmp", f"c_quizzes", c_quizzes)
+
+ ######################################################################
+
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+ seq_logproba=seq_logproba,
+ logit_transformer=lt_clean,
+ deterministic_synthesis=False,
+ device=self.device,
+ )
+
+ self.save_quiz_illustrations("/tmp", f"c_quizzes_A", c_quizzes)
+
+ c_quizzes = self.problem.p_a_flip(c_quizzes)
+
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+ seq_logproba=seq_logproba,
+ logit_transformer=lt_clean,
+ deterministic_synthesis=False,
+ device=self.device,
+ )
+
+ self.save_quiz_illustrations("/tmp", f"c_quizzes_B", c_quizzes)
+
+ print("DONE")
+ exit(0)
+
+ return c_quizzes.to("cpu")