S = self.height * self.width
return (
- (quizzes[:, 0 * (S + 1)] == self.l2tok(struct[0]))
- & (quizzes[:, 1 * (S + 1)] == self.l2tok(struct[1]))
- & (quizzes[:, 2 * (S + 1)] == self.l2tok(struct[2]))
- & (quizzes[:, 3 * (S + 1)] == self.l2tok(struct[3]))
+ (quizzes[:, 0 * (S + 1)] == self.l2tok[struct[0]])
+ & (quizzes[:, 1 * (S + 1)] == self.l2tok[struct[1]])
+ & (quizzes[:, 2 * (S + 1)] == self.l2tok[struct[2]])
+ & (quizzes[:, 3 * (S + 1)] == self.l2tok[struct[3]])
).all()
+ def get_structure(self, quizzes):
+ S = self.height * self.width
+ struct = tuple(
+ self.tok2l[n.item()] for n in quizzes.reshape(-1, 4, S + 1)[0, :, 0]
+ )
+ self.check_structure(quizzes, struct)
+ return struct
+
def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)):
assert check_structure(quizzes, struct)
return ar_mask
- def reconfigure(
- self,
- quizzes,
- struct_from=("A", "f_A", "B", "f_B"),
- struct_to=("f_B", "A", "f_A", "B"),
- ):
- assert check_structure(quizzes, struct_from)
+ def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")):
+ S = self.height * self.width
+ struct_from = self.get_structure(quizzes)
sf = dict((l, n) for n, l in enumerate(struct_from))
result = quizzes.new(quizzes.size())
- q = quizzes.reshape(-1, 4, 4 * (S + 1))
- r = reshape.reshape(-1, 4, 4 * (S + 1))
+ q = quizzes.reshape(-1, 4, S + 1)
+ r = result.reshape(-1, 4, S + 1)
- r[:, 0, :] = q[:, sf[struct_to[0]]]
- r[:, 1, :] = q[:, sf[struct_to[1]]]
- r[:, 2, :] = q[:, sf[struct_to[2]]]
- r[:, 3, :] = q[:, sf[struct_to[3]]]
+ r[:, 0] = q[:, sf[struct[0]], :]
+ r[:, 1] = q[:, sf[struct[1]], :]
+ r[:, 2] = q[:, sf[struct[2]], :]
+ r[:, 3] = q[:, sf[struct[3]], :]
return result
self.token_f_A = self.token_A + 1
self.token_B = self.token_f_A + 1
self.token_f_B = self.token_B + 1
+
self.l2tok = {
"A": self.token_A,
"f_A": self.token_f_A,
"f_B": self.token_f_B,
}
+ self.tok2l = {
+ self.token_A: "A",
+ self.token_f_A: "f_A",
+ self.token_B: "B",
+ self.token_f_B: "f_B",
+ }
+
self.nb_token_values = self.token_f_B + 1
self.height = 10
+ (1 - predicted_parts[:, :, None]) * white[None, None, :]
)
- img_A = self.add_frame(img_A, colors[:, 0], thickness=6)
- img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=6)
- img_B = self.add_frame(img_B, colors[:, 2], thickness=6)
- img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=6)
+ img_A = self.add_frame(img_A, colors[:, 0], thickness=8)
+ img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=8)
+ img_B = self.add_frame(img_B, colors[:, 2], thickness=8)
+ img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=8)
img_A = self.add_frame(img_A, white[None, :], thickness=2)
img_f_A = self.add_frame(img_f_A, white[None, :], thickness=2)
total=prompts.size(0),
)
+ quizzes[...] = 0
+ quizzes[:, 0 * (S + 1)] = self.token_A
+ quizzes[:, 1 * (S + 1)] = self.token_f_A
+ quizzes[:, 2 * (S + 1)] = self.token_B
+ quizzes[:, 3 * (S + 1)] = self.token_f_B
+
for quiz in quizzes:
q = quiz.reshape(4, S + 1)[:, 1:].reshape(4, self.height, self.width)
q[...] = 0
# grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
grids = Grids()
+ nb = 5
+ quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill])
+ print(grids.get_structure(quizzes))
+ blah = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B"))
+ print(grids.get_structure(blah))
+ exit(0)
+
# nb = 1000
# grids = problem.MultiThreadProblem(
# grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
# for t in [grids.task_symbols]:
print(t.__name__)
quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
+ print(grids.get_structure(quizzes))
predicted_parts = quizzes.new_zeros(quizzes.size(0), 4)
- predicted_parts[:, 3] = 1
+ predicted_parts[:, 3] = torch.randint(
+ 2, (quizzes.size(0),), device=quizzes.device
+ )
+ predicted_parts[:, :3] = 1 - predicted_parts[:, 3:]
correct_parts = torch.randint(2, (quizzes.size(0), 4), device=quizzes.device)
+ correct_parts[:, 1:2] = correct_parts[:, :1]
grids.save_quizzes_as_image(
"/tmp",
t.__name__ + ".png",