):
quizzes = quizzes.to("cpu")
- to_reconfigure = [result]
+ to_reconfigure = [quizzes]
if predicted_parts is not None:
to_reconfigure.append(predicted_parts)
if correct_parts is not None:
to_reconfigure.append(correct_parts)
- to_reconfigure = self.problem.reconfigure(
- to_reconfigure, ("A", "f_A", "B", "f_B")
- )
+ to_reconfigure = self.reconfigure(to_reconfigure, ("A", "f_A", "B", "f_B"))
- result = to_reconfigure.pop(0)
+ quizzes = to_reconfigure.pop(0)
if predicted_parts is not None:
predicted_parts = to_reconfigure.pop(0)
if correct_parts is not None:
mask=mask,
)
- predicted_parts = torch.tensor(mask, device=correct.device)[None, :]
+ predicted_parts = torch.tensor(mask, device=correct.device)[None, :].expand(
+ correct.size(0), -1
+ )
correct = (2 * correct - 1) * (predicted_parts.sum(dim=-1) == 1).long()
nb_correct = (correct == 1).long().sum()
from_w = torch.full((quizzes.size(0),), True, device=quizzes.device)
self.randomize_configuations_inplace(
- quizzes, structs=[s for s in self.understood_structures]
+ quizzes, structs=[s for s, m in self.understood_structures]
)
i = torch.randperm(quizzes.size(0), device=quizzes.device)
######################################################################
def make_ar_mask(self, quizzes, struct, mask):
- assert struct in [s for s in self.understood_structures]
+ assert struct in [s for s, m in self.understood_structures]
return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask)
######################################################################
)
self.randomize_configuations_inplace(
- model.train_w_quizzes, structs=[s for s in self.understood_structures]
+ model.train_w_quizzes, structs=[s for s, m in self.understood_structures]
)
######################################################################