def get_structure(self, quizzes):
S = self.height * self.width
struct = tuple(
- self.tok2l[n.item()] for n in quizzes.reshape(-1, 4, S + 1)[0, :, 0]
+ self.tok2l[n.item()]
+ for n in quizzes.reshape(quizzes.size(0), 4, S + 1)[0, :, 0]
)
self.check_structure(quizzes, struct)
return struct
sf = dict((l, n) for n, l in enumerate(struct_from))
result = quizzes.new(quizzes.size())
- q = quizzes.reshape(-1, 4, S + 1)
- r = result.reshape(-1, 4, S + 1)
+ q = quizzes.reshape(quizzes.size(0), 4, S + 1)
+ r = result.reshape(result.size(0), 4, S + 1)
r[:, 0] = q[:, sf[struct[0]], :]
r[:, 1] = q[:, sf[struct[1]], :]
return result
+ def non_trivial(self, quizzes):
+     """Return a boolean mask with one entry per quiz, True for non-trivial quizzes.
+
+     `quizzes` is reshaped to (N, 4, S + 1) with S = height * width and must
+     follow the ("A", "f_A", "B", "f_B") structure (asserted below). The first
+     token of each of the four parts is dropped before comparing grids.
+
+     A quiz is flagged trivial when A is identical to f_A or B is identical
+     to f_B. The result is meant to be used directly as a keep-mask, e.g.
+     `quizzes[self.non_trivial(quizzes)]`.
+
+     NOTE(review): the previous version returned True exactly when both pairs
+     were identical, i.e. it selected the trivial quizzes. Treating a single
+     identical pair as trivial follows the caller's "e.g. B=f(B)" criterion;
+     confirm this is the intended definition.
+     """
+     S = self.height * self.width
+     assert self.check_structure(quizzes, struct=("A", "f_A", "B", "f_B"))
+     # Keep only the S grid tokens of each part (skip the leading token).
+     a = quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
+     a_unchanged = (a[:, 0] == a[:, 1]).all(dim=1)
+     b_unchanged = (a[:, 2] == a[:, 3]).all(dim=1)
+     return ~(a_unchanged | b_unchanged)
+
def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)):
- assert check_structure(quizzes, struct)
+ assert self.check_structure(quizzes, struct)
ar_mask = quizzes.new_zeros(quizzes.size())
- a = ar_mask.reshape(-1, 4, -1)[:, :, 1:]
+ S = self.height * self.width
+ a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:]
a[:, 0, :] = mask[0]
a[:, 1, :] = mask[1]
a[:, 2, :] = mask[2]
def indices_select(self, quizzes, struct=("A", "f_A", "B", "f_B")):
S = self.height * self.width
- q = quizzes.reshape(-1, 4, S + 1)
+ q = quizzes.reshape(quizzes.size(0), 4, S + 1)
return (
(q[:, 0, 0] == self.l2tok[struct[0]])
& (q[:, 1, 0] == self.l2tok[struct[1]])
nrow=4,
margin=8,
):
+ quizzes = quizzes.to("cpu")
+
S = self.height * self.width
A, f_A, B, f_B = (
- quizzes.reshape(-1, 4, S + 1)[:, :, 1:]
- .reshape(-1, 4, self.height, self.width)
+ quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
+ .reshape(quizzes.size(0), 4, self.height, self.width)
.permute(1, 0, 2, 3)
)
def create_empty_quizzes(self, nb, struct=("A", "f_A", "B", "f_B")):
S = self.height * self.width
quizzes = torch.zeros(nb, 4 * (S + 1), dtype=torch.int64)
- quizzes[:, 0 * (S + 1)] = self.l2tok(struct[0])
- quizzes[:, 1 * (S + 1)] = self.l2tok(struct[1])
- quizzes[:, 2 * (S + 1)] = self.l2tok(struct[2])
- quizzes[:, 3 * (S + 1)] = self.l2tok(struct[3])
+ quizzes[:, 0 * (S + 1)] = self.l2tok[struct[0]]
+ quizzes[:, 1 * (S + 1)] = self.l2tok[struct[1]]
+ quizzes[:, 2 * (S + 1)] = self.l2tok[struct[2]]
+ quizzes[:, 3 * (S + 1)] = self.l2tok[struct[3]]
return quizzes
def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
+ S = self.height * self.width
+
if tasks is None:
tasks = self.all_tasks
######################################################################
-def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_device):
+def run_tests(model, quiz_machine, local_device=main_device):
with torch.autograd.no_grad():
model.eval().to(local_device)
model=model,
input=full_input[:2000],
result_dir=args.result_dir,
- deterministic_synthesis=deterministic_synthesis,
)
log_string(f"train_perplexity {n_epoch} model {model.id} {train_perplexity}")
- run_tests(model, quiz_machine, deterministic_synthesis=False)
+ run_tests(model, quiz_machine)
threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values
threshold = threshold[threshold.size(0) // 2]
# We discard the trivial ones, according to a criterion
# specific to the world quizzes (e.g. B=f(B))
- c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
+ c_quizzes = c_quizzes[quiz_machine.problem.non_trivial(c_quizzes)]
# We go through nb_rounds rounds and keep only quizzes on
# which
remains = [c_quizzes.size(0)]
for r in range(args.nb_rounds):
+ if c_quizzes.size(0) == 0:
+ break
+
number_correct_responses += quiz_machine.models_successes(models, c_quizzes)
nb_sure_correct = (number_correct_responses == r + 1).long().sum(dim=1)
remains.append(c_quizzes.size(0))
- if c_quizzes.size(0) == 0:
- break
-
if c_quizzes.size(0) > 0:
nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0)
recorded_validated.append(c_quizzes)
v = " ".join([str(n.item()) for n in r])
f.write(f"{n}: {v}\n")
- quiz_machine.save_quiz_illustrations(
- args.result_dir, prefix, vq, show_part_to_predict=False
- )
+ quiz_machine.save_quizzes_as_image(args.result_dir, prefix, vq)
######################################################################
temperature_cold=args.temperature_cold,
)
- quiz_machine.save_quiz_illustrations(
- args.result_dir, f"non_validated_{n_epoch:04d}_{model.id:02d}", c_quizzes
+ quiz_machine.problem.save_quizzes_as_image(
+ args.result_dir,
+ f"non_validated_{n_epoch:04d}_{model.id:02d}.png",
+ c_quizzes,
)
# Renew the training samples
logit_transformer=None,
deterministic_synthesis=False,
):
+ if input.size(0) == 0:
+ return
+
to_generate = (ar_mask.sum(0) > 0).nonzero()
if to_generate.min() > 0:
######################################################################
- def predict(self, input, struct, mask):
+ def predict(self, model, quizzes, struct, mask):
ar_mask = self.problem.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask)
result = quizzes * (1 - ar_mask)
- seq_logproba = torch.empty(fwd_quizzes, device=self.device)
+ seq_logproba = torch.empty(quizzes.size(0), device=self.device)
masked_inplace_autoregression(
model=model,
input=result,
ar_mask=ar_mask,
seq_logproba=seq_logproba,
- deterministic_synthesis=deterministic_synthesis,
+ deterministic_synthesis=False,
progress_bar_desc="accuracy",
device=self.device,
)
- nb_correct = (result == quizzes).min(dim=1).long()
+ correct = (result == quizzes).min(dim=1).values
return result, correct
def produce_results(
- self, n_epoch, model, input, result_dir, deterministic_synthesis
+ self,
+ n_epoch,
+ model,
+ input,
+ result_dir,
):
input = input.to(self.device)
- i = self.problem.indices_select(quizzes=input, struct=struct)
-
- input_fwd = input[i]
- test_result_fwd, test_correct_fwd = predict(
- input_fwd, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
- )
-
- input_bck = self.problem.reconfigure(
- predict(input[i == False], ("f_B", "f_A", "B", "A"), (0, 1, 1, 1))[0],
- struct=("A", "f_A", "B", "f_B"),
- )
-
- l = input_bck.size(1) // 4
- input_bck[:, 3 * l :] = input[i == False][:, :l]
+ result = input.new(input.size())
+ correct = torch.empty(input.size(0), device=input.device, dtype=torch.bool)
+
+ nb = 0
+ for struct, mask in [
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)),
+ (("f_B", "f_A", "B", "A"), (0, 1, 1, 1)),
+ ]:
+ i = self.problem.indices_select(quizzes=input, struct=struct)
+ nb += i.long().sum()
+ result[i], correct[i] = self.predict(
+ model=model, quizzes=input[i], struct=struct, mask=mask
+ )
- test_result_bck, test_correct_bck = predict(
- input_bck, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
- )
+ assert nb == input.size(0)
- main_test_accuracy = test_correct.sum() / test_correct.size(0)
+ main_test_accuracy = correct.sum() / correct.size(0)
##############################
- test_result = torch.cat([test_result_fwd[:64], test_result_bck[:64]], dim=0)
- test_correct = torch.cat([test_correct_fwd[:64], test_correct_bck[:64]], dim=0)
-
- self.save_quiz_illustrations(
+ self.problem.save_quizzes_as_image(
result_dir,
- f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
- quizzes=test_result,
- # mistakes=test_correct,
+ f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png",
+ quizzes=result[:128],
)
return main_test_accuracy
seq_logproba[...] = 0.0
+ c_quizzes = c_quizzes.to(self.device)
+ print(self.problem.get_structure(c_quizzes))
+ reversed_c_quizzes = self.problem.reconfigure(
+ c_quizzes, ("f_A", "A", "f_B", "B")
+ )
+
for model in models_for_validation:
# A, f(A), B | f(B)
- c_quizzes = c_quizzes.to(self.device)
result = c_quizzes.clone()
- ar_mask = self.problem.make_ar_mask(result, shape="fwd_3_bck_3")
+ ar_mask = self.problem.make_ar_mask(
+ result, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)
+ )
masked_inplace_autoregression(
model=model,
# -------------------------------
# f(A), A, f(B) | B
- c_quizzes = self.problem.flip(c_quizzes, pairwise_flip=True).to(self.device)
- result = c_quizzes.clone()
+ result = reversed_c_quizzes.clone()
- ar_mask = self.problem.make_ar_mask(result, shape="fwd_3_bck_3")
+ ar_mask = self.problem.make_ar_mask(
+ result, ("f_A", "A", "f_B", "B"), mask=(0, 0, 0, 1)
+ )
masked_inplace_autoregression(
model=model,
device=self.device,
)
- correct *= (c_quizzes == result).long().min(dim=-1).values
+ correct *= (reversed_c_quizzes == result).long().min(dim=-1).values
# -------------------------------
temperature_hot=1.0,
temperature_cold=1.0,
):
- c_quizzes = torch.empty(
- nb,
- self.problem.seq_len,
- device=self.device,
- dtype=torch.int64,
+ c_quizzes = self.problem.create_empty_quizzes(nb, ("f_B", "f_A", "A", "B")).to(
+ self.device
)
seq_logproba = torch.zeros(nb, device=self.device)
# )
# lt_clean = None
- c_quizzes[...] = self.problem.token_backward
-
masked_inplace_autoregression(
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
+ ar_mask=self.problem.make_ar_mask(
+ c_quizzes, ("f_B", "f_A", "A", "B"), (1, 0, 0, 0)
+ ),
seq_logproba=seq_logproba,
logit_transformer=lt_noisy,
deterministic_synthesis=False,
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+ ar_mask=self.problem.make_ar_mask(
+ c_quizzes, ("f_B", "f_A", "A", "B"), (0, 1, 1, 1)
+ ),
seq_logproba=seq_logproba,
logit_transformer=lt_clean,
deterministic_synthesis=False,
device=self.device,
)
- c_quizzes = self.problem.p_a_flip(c_quizzes)
+ c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B"))
masked_inplace_autoregression(
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+ ar_mask=self.problem.make_ar_mask(
+ c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+ ),
seq_logproba=seq_logproba,
logit_transformer=lt_clean,
deterministic_synthesis=False,