self.prompt_noise = prompt_noise
self.understood_structures = [
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)),
- (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)),
- (("B", "f_B", "A", "f_A"), (0, 0, 0, 1)),
- (("f_B", "B", "f_A", "A"), (0, 0, 0, 1)),
- (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)),
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)),
+ (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)),
+ (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0)),
+ (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0)),
+ (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0)),
]
self.LOCK_C_QUIZZES = threading.Lock()
quizzes, from_w = quizzes[i], from_w[i]
self.randomize_configuations_inplace(
- quizzes, structs=[s for s, m in self.understood_structures]
+ quizzes, structs=[s for s, m, _ in self.understood_structures]
)
if self.prompt_noise > 0.0:
- for struct, mask in self.understood_structures:
+ for struct, mask, noise_mask in self.understood_structures:
i = self.problem.indices_select(quizzes=quizzes, struct=struct)
if i.any():
quizzes[i] = self.problem.inject_noise(
- quizzes[i],
- self.prompt_noise,
- struct=struct,
- mask=tuple(1 - k for k in mask),
+ quizzes[i], self.prompt_noise, struct=struct, mask=noise_mask
)
return quizzes, from_w
######################################################################
def make_ar_mask(self, quizzes, struct, mask):
- assert struct in [s for s, m in self.understood_structures]
+ assert struct in [s for s, _, _ in self.understood_structures]
return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask)
######################################################################
nb = 0
# We consider all the configurations that we train for
- for struct, mask in self.understood_structures:
+ for struct, mask, noise_mask in self.understood_structures:
i = self.problem.indices_select(quizzes=input, struct=struct)
nb += i.long().sum()
result[i], correct[i] = self.predict(