dim=1,
)
else:
- flipped = torch.cat(
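+ # A quiz is four (S + 1)-token segments, each starting with a
+ # direction token. Flipping a forward quiz rotates the segments
+ # right by one and marks them all as backward.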
+ flipped_from_forward = torch.cat(
[
quizzes[:, 3 * (S + 1) : 3 * (S + 1) + S + 1],
- quizzes[:, 2 * (S + 1) : 2 * (S + 1) + S + 1],
+ quizzes[:, 0 * (S + 1) : 0 * (S + 1) + S + 1],
quizzes[:, 1 * (S + 1) : 1 * (S + 1) + S + 1],
+ quizzes[:, 2 * (S + 1) : 2 * (S + 1) + S + 1],
+ ],
+ dim=1,
+ )
+ flipped_from_forward[:, torch.arange(4) * (S + 1)] = self.token_backward
+
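+ # The inverse rotation: a backward quiz's segments move left by
+ # one and are marked forward, so the two flips are mutual inverses.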
+ flipped_from_backward = torch.cat(
+ [
+ quizzes[:, 1 * (S + 1) : 1 * (S + 1) + S + 1],
+ quizzes[:, 2 * (S + 1) : 2 * (S + 1) + S + 1],
+ quizzes[:, 3 * (S + 1) : 3 * (S + 1) + S + 1],
quizzes[:, 0 * (S + 1) : 0 * (S + 1) + S + 1],
],
dim=1,
)
+ flipped_from_backward[:, torch.arange(4) * (S + 1)] = self.token_forward
+
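+ # Pick, per quiz, which of the two flips applies, according to its
+ # first direction token.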
- m = (flipped[:, 0] == self.token_forward).long()
- flipped[:, 0 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward
- flipped[:, 1 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward
- flipped[:, 2 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward
- flipped[:, 3 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward
+ m = (quizzes[:, 0] == self.token_forward).long()[:, None]
+ flipped = m * flipped_from_forward + (1 - m) * flipped_from_backward
return flipped
parser.add_argument("--nb_gpts", type=int, default=5)
+parser.add_argument("--max_fail_to_validate", type=int, default=1)
+
parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
parser.add_argument("--proba_understands", type=float, default=0.9)
parser.add_argument("--nb_rounds", type=int, default=3)
+parser.add_argument("--noise_level", type=float, default=0)
+
parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
parser.add_argument("--p2a_only", action="store_true", default=False)
if nb_train_samples % args.batch_size == 0:
optimizer.zero_grad()
+ targets = input
+
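+ # With probability noise_level, replace each non-direction input
+ # token with one drawn uniformly from the vocabulary; the loss is
+ # still computed against the clean targets.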
+ if args.noise_level > 0:
+ m = (
+ (torch.rand(targets.size(), device=targets.device) < args.noise_level)
+ & (targets != quiz_machine.problem.token_forward)
+ & (targets != quiz_machine.problem.token_backward)
+ ).long()
+ input = (1 - m) * input.clone() + m * torch.randint(
+ vocabulary_size, input.size(), device=input.device
+ )
+
output = model(mygpt.BracketedSequence(input)).x
loss_per_token = F.cross_entropy(
- output.transpose(1, 2), input, reduction="none"
+ output.transpose(1, 2), targets, reduction="none"
)
loss = loss_per_token.mean()
acc_train_loss += loss.item() * input.size(0)
nb_validated = 0
recorded_validated = []
- # recorded_too_simple = []
start_time = time.perf_counter()
c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
# We go through nb_rounds rounds and keep only quizzes on
- # which models respond always the same through rounds
+ # which the models' responses are consistent across rounds, with
+ # at least one and at most args.max_fail_to_validate models failing
+
- total_nb_validated = 0
- ms = 0
+ ms = 0  # per-model success counts, accumulated over the rounds
for r in range(args.nb_rounds):
ms += quiz_machine.models_successes(models, c_quizzes)
- # print(f"{r=} {ms=}")
- i = ((ms == r + 1).long().sum(dim=1) == ms.size(1) - 1) & (
- (ms == 0).long().sum(dim=1) == 1
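+ # ms[i, j] counts the rounds in which model j answered quiz i
+ # correctly; a quiz is kept only if every model is consistent,
+ # i.e. succeeded or failed at every round so far.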
+ nb_sure_and_correct = (ms == r + 1).long().sum(dim=1)
+ nb_sure_and_fail = (ms == 0).long().sum(dim=1)
+ to_keep = (
+ (nb_sure_and_correct + nb_sure_and_fail == ms.size(1))
+ & (nb_sure_and_fail >= 1)
+ & (nb_sure_and_fail <= args.max_fail_to_validate)
)
- c_quizzes = c_quizzes[i]
- ms = ms[i]
+
+ c_quizzes = c_quizzes[to_keep]
+ ms = ms[to_keep]
+ print(f"Round {r} remains {c_quizzes.size(0)}")
if c_quizzes.size(0) == 0:
break
if c_quizzes.size(0) > 0:
nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0)
- total_nb_validated = nb_validated_per_model.sum().item()
recorded_validated.append(c_quizzes)
+ total_nb_validated = nb_validated_per_model.sum().item()
+
duration = time.perf_counter() - start_time
if total_nb_validated > 0:
)
validated_quizzes = torch.cat(recorded_validated, dim=0)
- # too_simple_quizzes = torch.cat(recorded_too_simple, dim=0)
######################################################################
# store the new c_quizzes which have been validated
args.result_dir, prefix, vq, show_part_to_predict=False
)
- # vq = too_simple_quizzes[torch.randperm(too_simple_quizzes.size(0))[:128]]
-
- # if vq.size(0) > 0:
- # prefix = f"culture_c_quiz_{n_epoch:04d}_too_simple"
- # quiz_machine.save_quiz_illustrations(
- # args.result_dir, prefix, vq, show_part_to_predict=False
- # )
-
######################################################################
)
log_string(f"wrote {filename}")
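+ # Also save a batch of freshly generated, non-validated c_quizzes
+ # from each weak model for visual inspection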
+ for model in weakest_models:
+ c_quizzes = quiz_machine.generate_c_quizzes(
+ 128,
+ model_for_generation=model,
+ p2a_only=args.p2a_only,
+ temperature_hot=args.temperature_hot,
+ temperature_cold=args.temperature_cold,
+ )
+
+ quiz_machine.save_quiz_illustrations(
+ args.result_dir, f"non_validated_{n_epoch:04d}_{model.id:02d}", c_quizzes
+ )
+
# Renew the training samples
for model in weakest_models:
bs = self.readout(bs)
return bs
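+ # Run only a slice of the model: from the token input up to trunk
+ # layer end_layer (exclusive), or from an intermediate activation
+ # through the trunk layers from start_layer on, plus the readout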
+ def partial_forward(self, bs, start_layer=None, end_layer=None):
+ if start_layer is None:
+ # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
+ bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+ bs = self.embedding(bs)
+ if end_layer is not None:
+ return self.trunk[:end_layer](bs)
+ else:
+ bs = self.trunk(bs)
+ bs = self.readout(bs)
+ return bs
+ else:
+ bs = self.trunk[start_layer:](bs)
+ bs = self.readout(bs)
+ return bs
+
def record_attention(self, v=True):
for m in self.modules():
if isinstance(m, QKVAttention):
)
return c_quizzes.to("cpu")
+
+ ######################################################################
+
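+ # Experimental: sample a hot draft of a quiz, then repeatedly
+ # complete it at cold temperature in the two orientations until it
+ # reaches a fixed point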
+ def generate_c_quizzes_fixed_point(
+ self,
+ nb,
+ model_for_generation,
+ p2a_only=False,
+ temperature_hot=1.0,
+ temperature_cold=1.0,
+ ):
+ c_quizzes = torch.empty(
+ nb,
+ self.prompt_len + self.answer_len,
+ device=self.device,
+ dtype=torch.int64,
+ )
+
+ seq_logproba = torch.zeros(nb, device=self.device)
+
+ lt_noisy = lambda s, logits: logits / temperature_hot
+ lt_clean = lambda s, logits: logits / temperature_cold
+
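+ # Start from an all-backward sequence and sample a hot draft of the
+ # part selected by the "bck_0" mask (presumably the first segment)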
+ c_quizzes[...] = self.problem.token_backward
+
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
+ seq_logproba=seq_logproba,
+ logit_transformer=lt_noisy,
+ deterministic_synthesis=False,
+ device=self.device,
+ )
+
+ self.save_quiz_illustrations("/tmp", f"c_quizzes_before", c_quizzes)
+
+ c_quizzes = self.problem.p_a_flip(c_quizzes)
+
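+ # Re-predict the parts named by the "fwd_3_bck_123" mask in each
+ # orientation, and stop once everything past column 202 (presumably
+ # the B / f(B) pair) survives a full round trip unchanged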
+ while True:
+ print("ITERATION")
+
+ pred = c_quizzes.clone()
+
+ c_quizzes = self.problem.p_a_flip(c_quizzes)
+
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+ seq_logproba=seq_logproba,
+ logit_transformer=lt_clean,
+ deterministic_synthesis=False,
+ device=self.device,
+ )
+
+ c_quizzes = self.problem.p_a_flip(c_quizzes)
+
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+ seq_logproba=seq_logproba,
+ logit_transformer=lt_clean,
+ deterministic_synthesis=False,
+ device=self.device,
+ )
+
+ if pred[:, 202:].equal(c_quizzes[:, 202:]):
+ break
+
+ self.save_quiz_illustrations("/tmp", f"c_quizzes_after", c_quizzes)
+
+ exit(0)
+
+ return c_quizzes.to("cpu")
+
+ ######################################################################
+
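+ # Experimental: sample two hot drafts independently, blend their
+ # activations halfway up the trunk, decode the blend into a new
+ # quiz, then clean it up with two cold-temperature completions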
+ def generate_c_quizzes_mixing(
+ self,
+ nb,
+ model_for_generation,
+ p2a_only=False,
+ temperature_hot=1.0,
+ temperature_cold=1.0,
+ ):
+ c_quizzes = torch.empty(
+ nb,
+ self.prompt_len + self.answer_len,
+ device=self.device,
+ dtype=torch.int64,
+ )
+
+ c_quizzes_1 = torch.empty(
+ nb,
+ self.prompt_len + self.answer_len,
+ device=self.device,
+ dtype=torch.int64,
+ )
+
+ c_quizzes_2 = torch.empty(
+ nb,
+ self.prompt_len + self.answer_len,
+ device=self.device,
+ dtype=torch.int64,
+ )
+
+ seq_logproba = torch.zeros(nb, device=self.device)
+
+ lt_noisy = lambda s, logits: logits / temperature_hot
+ lt_clean = lambda s, logits: logits / temperature_cold
+
+ ######################################################################
+
+ c_quizzes_1[...] = self.problem.token_backward
+ ar_mask = self.problem.make_ar_mask(c_quizzes_1, shape="fwd_012_bck_0")
+
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes_1,
+ ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
+ logit_transformer=lt_noisy,
+ deterministic_synthesis=False,
+ device=self.device,
+ )
+
+ self.save_quiz_illustrations("/tmp", f"c_quizzes_1", c_quizzes_1)
+
+ c_quizzes_2[...] = self.problem.token_backward
+
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes_2,
+ ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
+ logit_transformer=lt_noisy,
+ deterministic_synthesis=False,
+ device=self.device,
+ )
+
+ self.save_quiz_illustrations("/tmp", f"c_quizzes_2", c_quizzes_2)
+
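+ # Mix the two drafts in activation space, at the middle of the trunk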
+ h = len(model_for_generation.trunk) // 2
+
+ with torch.autograd.no_grad():
+ t = model_for_generation.training
+ model_for_generation.eval()
+
+ bs1 = model_for_generation.partial_forward(
+ mygpt.BracketedSequence(c_quizzes_1), end_layer=h
+ )
+ bs2 = model_for_generation.partial_forward(
+ mygpt.BracketedSequence(c_quizzes_2), end_layer=h
+ )
+
+ alpha = 0.1
+
+ output = model_for_generation.partial_forward(
+ mygpt.BracketedSequence(alpha * bs1.x + (1 - alpha) * bs2.x),
+ start_layer=h,
+ ).x
+
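+ # Decode the blended activations into tokens, keeping only the
+ # positions selected by the generation mask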
+ dist = torch.distributions.categorical.Categorical(logits=output)
+ c_quizzes[...] = dist.sample()
+
+ c_quizzes[...] = (
+ ar_mask * c_quizzes + (1 - ar_mask) * self.problem.token_backward
+ )
+
+ model_for_generation.train(t)
+
+ self.save_quiz_illustrations("/tmp", f"c_quizzes", c_quizzes)
+
+ ######################################################################
+
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+ seq_logproba=seq_logproba,
+ logit_transformer=lt_clean,
+ deterministic_synthesis=False,
+ device=self.device,
+ )
+
+ self.save_quiz_illustrations("/tmp", f"c_quizzes_A", c_quizzes)
+
+ c_quizzes = self.problem.p_a_flip(c_quizzes)
+
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+ seq_logproba=seq_logproba,
+ logit_transformer=lt_clean,
+ deterministic_synthesis=False,
+ device=self.device,
+ )
+
+ self.save_quiz_illustrations("/tmp", f"c_quizzes_B", c_quizzes)
+
+ print("DONE")
+ exit(0)
+
+ return c_quizzes.to("cpu")