def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")):
S = self.height * self.width
+ result = quizzes.new(quizzes.size())
+
+ struct_from = self.get_structure(quizzes[:1])
+ i = self.indices_select(quizzes, struct_from)
- struct_from = self.get_structure(quizzes)
sf = dict((l, n) for n, l in enumerate(struct_from))
- result = quizzes.new(quizzes.size())
- q = quizzes.reshape(quizzes.size(0), 4, S + 1)
- r = result.reshape(result.size(0), 4, S + 1)
+ q = quizzes.reshape(-1, 4, S + 1)[i]
+
+ result[i, 0 * (S + 1) : 1 * (S + 1)] = q[:, sf[struct[0]], :]
+ result[i, 1 * (S + 1) : 2 * (S + 1)] = q[:, sf[struct[1]], :]
+ result[i, 2 * (S + 1) : 3 * (S + 1)] = q[:, sf[struct[2]], :]
+ result[i, 3 * (S + 1) : 4 * (S + 1)] = q[:, sf[struct[3]], :]
+
+ j = i == False
- r[:, 0] = q[:, sf[struct[0]], :]
- r[:, 1] = q[:, sf[struct[1]], :]
- r[:, 2] = q[:, sf[struct[2]], :]
- r[:, 3] = q[:, sf[struct[3]], :]
+ if j.any():
+ result[j] = self.reconfigure(quizzes[j], struct=struct)
return result
y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
- y[:, :, :, torch.arange(0, y.size(3), scale)] = 0
- y[:, :, torch.arange(0, y.size(2), scale), :] = 0
+ y[:, :, :, torch.arange(0, y.size(3), scale)] = 224
+ y[:, :, torch.arange(0, y.size(2), scale), :] = 224
for n in range(m.size(0)):
for i in range(m.size(1)):
nb = 5
quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill])
+ print(quizzes)
print(grids.get_structure(quizzes))
quizzes = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B"))
+ print("DEBUG2", quizzes)
print(grids.get_structure(quizzes))
+ print(quizzes)
i = torch.rand(quizzes.size(0)) < 0.5
self.prompt_len = None
self.answer_len = None
+ self.configurations = [
+ ("A", "f_A", "B", "f_B"), # The standard order
+ ("f_A", "A", "f_B", "B"), # The reverse order for validation
+ ("f_B", "f_A", "A", "B"), # The synthesis order
+ ]
+
self.LOCK_C_QUIZZES = threading.Lock()
self.train_c_quizzes = []
self.test_c_quizzes = []
nb = 0
for struct, mask in [
(("A", "f_A", "B", "f_B"), (0, 0, 0, 1)),
- (("f_B", "f_A", "B", "A"), (0, 1, 1, 1)),
+ (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)),
+ (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)),
]:
i = self.problem.indices_select(quizzes=input, struct=struct)
nb += i.long().sum()
model=model, quizzes=input[i], struct=struct, mask=mask
)
+ print(f"{nb=} {input.size(0)=}")
assert nb == input.size(0)
main_test_accuracy = correct.sum() / correct.size(0)
######################################################################
- def flip_half_in_place(self, quizzes):
- r = torch.rand(quizzes.size(0), device=quizzes.device) < 0.5
- i = self.problem.indices_select(
- quizzes=quizzes, struct=("A", "f_A", "B", "f_B")
- )
- quizzes[i & r] = self.problem.reconfigure(
- quizzes[i & r], struct=("f_B", "f_A", "B", "A")
- )
- j = self.problem.indices_select(
- quizzes=quizzes, struct=("f_B", "f_A", "B", "A")
- )
- quizzes[j & r] = self.problem.reconfigure(
- quizzes[j & r], struct=("A", "f_A", "B", "f_B")
+ def randomize_configuations_inplace(self, quizzes, configurations):
+ r = torch.randint(
+ len(configurations), (quizzes.size(0),), device=quizzes.device
)
+ for c in range(len(configurations)):
+ quizzes[r == c] = self.problem.reconfigure(
+ quizzes[r == c], struct=configurations[c]
+ )
+
def create_w_quizzes(self, model, nb_train_samples, nb_test_samples):
model.train_w_quizzes = self.problem.generate_w_quizzes(nb_train_samples)
model.test_w_quizzes = self.problem.generate_w_quizzes(nb_test_samples)
- self.flip_half_in_place(model.train_w_quizzes)
- self.flip_half_in_place(model.test_w_quizzes)
+ self.randomize_configuations_inplace(
+ model.train_w_quizzes, configurations=self.configurations
+ )
+
+ self.randomize_configuations_inplace(
+ model.test_w_quizzes, configurations=self.configurations
+ )
+
+ # print(model.train_w_quizzes.sum())
######################################################################
model.train_w_quizzes.size(0)
)
- self.flip_half_in_place(model.train_w_quizzes)
+ self.randomize_configuations_inplace(
+ model.train_w_quizzes, configurations=self.configurations
+ )
######################################################################