):
quizzes = quizzes.to("cpu")
- if not self.check_structure(quizzes, ("A", "f_A", "B", "f_B")):
- print(f"**WARNING** {filename} is not in A/f_A/B/f_B order")
+ to_reconfigure = [result]
+ if predicted_parts is not None:
+ to_reconfigure.append(predicted_parts)
+ if correct_parts is not None:
+ to_reconfigure.append(correct_parts)
+
+ to_reconfigure = self.problem.reconfigure(
+ to_reconfigure, ("A", "f_A", "B", "f_B")
+ )
+
+ result = to_reconfigure.pop(0)
+ if predicted_parts is not None:
+ predicted_parts = to_reconfigure.pop(0)
+ if correct_parts is not None:
+ correct_parts = to_reconfigure.pop(0)
S = self.height * self.width
self.prompt_len = None
self.answer_len = None
- self.train_struct = [
- ("A", "f_A", "B", "f_B"), # The standard order
- ("f_A", "A", "f_B", "B"), # The reverse order for validation
- ("B", "f_B", "A", "f_A"),
- ("f_B", "B", "f_A", "A"),
- ("f_B", "f_A", "A", "B"), # The synthesis order
- ("f_B", "f_A", "A", "B"), # twice!
+ self.understood_structures = [
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)),
+ (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)),
+ (("B", "f_B", "A", "f_A"), (0, 0, 0, 1)),
+ (("f_B", "B", "f_A", "A"), (0, 0, 0, 1)),
+ (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)),
]
self.LOCK_C_QUIZZES = threading.Lock()
quizzes = w_quizzes.clone()
from_w = torch.full((quizzes.size(0),), True, device=quizzes.device)
- self.randomize_configuations_inplace(quizzes, structs=self.train_struct)
+ self.randomize_configuations_inplace(
+ quizzes, structs=[s for s in self.understood_structures]
+ )
i = torch.randperm(quizzes.size(0), device=quizzes.device)
######################################################################
def make_ar_mask(self, quizzes, struct, mask):
- assert struct in self.train_struct
+ assert struct in [s for s in self.understood_structures]
return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask)
######################################################################
nb = 0
# We consider all the configurations that we train for
- for struct, mask in [
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)),
- (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)),
- (("B", "f_B", "A", "f_A"), (0, 0, 0, 1)),
- (("f_B", "B", "f_A", "A"), (0, 0, 0, 1)),
- (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)),
- ]:
+ for struct, mask in self.understood_structures:
i = self.problem.indices_select(quizzes=input, struct=struct)
nb += i.long().sum()
result[i], correct[i] = self.predict(
predicted_parts = predicted_parts[:128]
correct_parts = correct_parts[:128]
- result, predicted_parts, correct_parts = self.problem.reconfigure(
- [result, predicted_parts, correct_parts], ("A", "f_A", "B", "f_B")
- )
-
self.problem.save_quizzes_as_image(
result_dir,
f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png",
)
self.randomize_configuations_inplace(
- model.train_w_quizzes, structs=self.train_struct
+ model.train_w_quizzes, structs=[s for s in self.understood_structures]
)
######################################################################