if science_w_quizzes is not None:
struct = ("A", "f_A", "B", "f_B")
mask = (0, 0, 0, 1)
- result, correct = quiz_machine.predict(
+ result, correct, _ = quiz_machine.predict(
model=model,
quizzes=science_w_quizzes.to(main_device),
struct=struct,
solved_c_quizzes = c_quizzes[:, None, :].expand(-1, len(models), -1).clone()
+ seq_logproba = torch.zeros(
+ c_quizzes.size(0), len(models), device=solved_c_quizzes.device
+ )
+
for m in models:
- solved_c_quizzes[:, m.id] = quiz_machine.predict(
+ (
+ solved_c_quizzes[:, m.id],
+ _,
+ seq_logproba[:, m.id],
+ ) = quiz_machine.predict(
m,
solved_c_quizzes[:, m.id],
struct=("A", "f_A", "B", "f_B"),
mask=(0, 0, 0, 1),
)
+ #!!!!!!!!!!!!!!!!!!!!
+ l = quiz_machine.models_logprobas(
+ models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+ )
+ for s in range(seq_logproba.size(0)):
+ print(f"-- {s=} ----------------")
+ for m in range(seq_logproba.size(1)):
+ print("DEBUG", seq_logproba[s, m].item(), l[s, m].item())
+ exit(0)
+ #!!!!!!!!!!!!!!!!!!!!!!!!!
+
# FINISH
seq_logproba = quiz_machine.models_logprobas(
record_new_c_quizzes(
models,
quiz_machine,
- nb_errorsfor_train=args.nb_new_c_quizzes_for_train,
+ nb_for_train=args.nb_new_c_quizzes_for_train,
nb_for_test=args.nb_new_c_quizzes_for_test,
)
model,
input,
ar_mask,
- seq_logproba,
+ acc_seq_logproba,
deterministic_synthesis=False,
):
if input.size(0) == 0:
all_n = torch.arange(t_next.size(0))
- seq_logproba += logits[all_n, t_next]
+ acc_seq_logproba += ar_mask[:, s] * logits[all_n, t_next]
input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
model,
input,
ar_mask,
- seq_logproba=None,
+ seq_logproba,
progress_bar_desc=None,
):
assert input.size() == ar_mask.size()
- if seq_logproba is None:
- seq_logproba = torch.empty(input.size(0), device=self.device)
-
batches = zip(
input.split(self.batch_size),
ar_mask.split(self.batch_size),
model=model,
input=input,
ar_mask=ar_mask,
- seq_logproba=seq_logproba,
+ acc_seq_logproba=seq_logproba,
deterministic_synthesis=False,
)
######################################################################
def predict(self, model, quizzes, struct, mask):
+ quizzes = quizzes.to(self.device)
ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, mask=mask)
result = quizzes * (1 - ar_mask)
- seq_logproba = torch.empty(quizzes.size(0), device=self.device)
+ seq_logproba = torch.zeros(quizzes.size(0), device=self.device)
self.autoregression(
model=model,
correct = (result == quizzes).min(dim=1).values.long()
- return result, correct
+ result = result.to("cpu")
+ correct = correct.to("cpu")
+ seq_logproba = seq_logproba.to("cpu")
+
+ return result, correct, seq_logproba
######################################################################
for struct, mask_generate, _, _ in self.test_structures:
i = self.problem.indices_select(quizzes=input, struct=struct)
nb += i.long().sum()
- result[i], correct[i] = self.predict(
+ result[i], correct[i], _ = self.predict(
model=model, quizzes=input[i], struct=struct, mask=mask_generate
)
predicted_parts[i] = torch.tensor(mask_generate, device=self.device)[