("gray", [128, 128, 128]),
]
- def make_ar_mask(self, quizzes, shape="fwd_3_bck_123"):
+ def check_structure(self, quizzes, struct):
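+ # A quiz is a flat sequence of four segments of length S + 1: one structure
+ # token followed by the S cells of the corresponding grid. This checks that
+ # the four structure tokens match the labels given in struct.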
S = self.height * self.width
- assert (
- (
- (quizzes[:, 0] == self.token_forward)
- | (quizzes[:, 0] == self.token_backward)
- )
- & (quizzes[:, 0] == quizzes[:, 1 * (S + 1)])
- & (quizzes[:, 0] == quizzes[:, 2 * (S + 1)])
- & (quizzes[:, 0] == quizzes[:, 3 * (S + 1)])
+ return (
+ (quizzes[:, 0 * (S + 1)] == self.l2tok[struct[0]])
+ & (quizzes[:, 1 * (S + 1)] == self.l2tok[struct[1]])
+ & (quizzes[:, 2 * (S + 1)] == self.l2tok[struct[2]])
+ & (quizzes[:, 3 * (S + 1)] == self.l2tok[struct[3]])
).all()
- T = torch.arange(quizzes.size(1), device=quizzes.device)
-
- if shape == "fwd_3_bck_123":
- forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
- backward_mask = ((T % (S + 1) != 0) & (T >= 1 * (S + 1))).long()
- elif shape == "fwd_012_bck_0":
- forward_mask = ((T % (S + 1) != 0) & (T < 3 * (S + 1))).long()
- backward_mask = ((T % (S + 1) != 0) & (T < 1 * (S + 1))).long()
- elif shape == "fwd_3_bck_3":
- forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
- backward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
- else:
- raise ValueError(shape)
-
- is_forward = (quizzes[:, 0] == self.token_forward).long()
+ def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)):
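+ # mask[k] == 1 marks every cell token of the k-th grid as a position to be
+ # generated autoregressively; the structure tokens themselves stay at 0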
+ assert self.check_structure(quizzes, struct)
- return (
- is_forward[:, None] * forward_mask[None, :]
- + (1 - is_forward)[:, None] * backward_mask[None, :]
- )
+ ar_mask = quizzes.new_zeros(quizzes.size())
- def p_a_flip(self, quizzes, pairwise_flip=False):
- S = self.height * self.width
+ a = ar_mask.reshape(ar_mask.size(0), 4, -1)[:, :, 1:]
+ a[:, 0, :] = mask[0]
+ a[:, 1, :] = mask[1]
+ a[:, 2, :] = mask[2]
+ a[:, 3, :] = mask[3]
- assert (
- (
- (quizzes[:, 0] == self.token_forward)
- | (quizzes[:, 0] == self.token_backward)
- )
- & (quizzes[:, 0] == quizzes[:, 1 * (S + 1)])
- & (quizzes[:, 0] == quizzes[:, 2 * (S + 1)])
- & (quizzes[:, 0] == quizzes[:, 3 * (S + 1)])
- ).all()
+ return ar_mask
- if pairwise_flip:
- flipped = torch.cat(
- [
- quizzes[:, 1 * (S + 1) : 1 * (S + 1) + S + 1],
- quizzes[:, 0 * (S + 1) : 0 * (S + 1) + S + 1],
- quizzes[:, 3 * (S + 1) : 3 * (S + 1) + S + 1],
- quizzes[:, 2 * (S + 1) : 2 * (S + 1) + S + 1],
- ],
- dim=1,
- )
- else:
- flipped_from_forward = torch.cat(
- [quizzes[:, 3 * (S + 1) :], quizzes[:, : 3 * (S + 1)]],
- dim=1,
- )
- flipped_from_forward[:, torch.arange(4) * (S + 1)] = self.token_backward
+ def reconfigure(
+ self,
+ quizzes,
+ struct_from=("A", "f_A", "B", "f_B"),
+ struct_to=("f_B", "A", "f_A", "B"),
+ ):
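+ # Permute the four (S + 1)-token segments of each quiz so that a batch laid
+ # out as struct_from comes out laid out as struct_to; the structure tokens
+ # travel with their grids, so the result satisfies check_structure(_, struct_to).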
+ S = self.height * self.width
+
+ assert self.check_structure(quizzes, struct_from)
- flipped_from_backward = torch.cat(
- [quizzes[:, S + 1 :], quizzes[:, : S + 1]], dim=1
- )
- flipped_from_backward[:, torch.arange(4) * (S + 1)] = self.token_forward
+ sf = dict((l, n) for n, l in enumerate(struct_from))
- m = (quizzes[:, 0] == self.token_forward).long()[:, None]
+ result = quizzes.new(quizzes.size())
+ q = quizzes.reshape(-1, 4, S + 1)
+ r = result.reshape(-1, 4, S + 1)
- flipped = m * flipped_from_forward + (1 - m) * flipped_from_backward
+ r[:, 0, :] = q[:, sf[struct_to[0]]]
+ r[:, 1, :] = q[:, sf[struct_to[1]]]
+ r[:, 2, :] = q[:, sf[struct_to[2]]]
+ r[:, 3, :] = q[:, sf[struct_to[3]]]
- return flipped
+ return result
def __init__(
self,
tasks=None,
):
self.colors = torch.tensor([c for _, c in self.named_colors])
- self.token_forward = len(self.colors)
- self.token_backward = self.token_forward + 1
+
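+ # token values 0 .. len(colors) - 1 encode cell colors; the four extra
+ # tokens mark the start of the A, f_A, B and f_B segments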
+ self.token_A = len(self.colors)
+ self.token_f_A = self.token_A + 1
+ self.token_B = self.token_f_A + 1
+ self.token_f_B = self.token_B + 1
+ self.l2tok = {
+ "A": self.token_A,
+ "f_A": self.token_f_A,
+ "B": self.token_B,
+ "f_B": self.token_f_B,
+ }
+
+ self.nb_token_values = self.token_f_B + 1
+
self.height = 10
self.width = 10
self.cache_rec_coo = {}
######################################################################
- def frame2img(self, x, scale=15):
- x = x.reshape(x.size(0), self.height, self.width)
+ def grid2img(self, x, scale=15):
m = torch.logical_and(x >= 0, x < len(self.colors)).long()
y = self.colors[x * m].permute(0, 3, 1, 2)
s = y.shape
# expand every cell into a scale x scale block of pixels before drawing
y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
y[:, :, :, torch.arange(0, y.size(3), scale)] = 0
y[:, :, torch.arange(0, y.size(2), scale), :] = 0
- y = y[:, :, 1:, 1:]
for n in range(m.size(0)):
for i in range(m.size(1)):
for j in range(m.size(2)):
- if x[n, i, j] == self.token_forward:
- for k in range(2, scale - 2):
- y[
- n,
- :,
- i * scale + k,
- j * scale + scale - 5 - abs(k - scale // 2),
- ] = 0
-
- elif x[n, i, j] == self.token_backward:
- for k in range(2, scale - 2):
- y[
- n, :, i * scale + k, j * scale + 3 + abs(k - scale // 2)
- ] = 0
- # y[n, :, i * scale + k, j * scale + k - l] = 0
- # y[
- # n, :, i * scale + scale - 1 - k, j * scale + k - l
- # ] = 0
+ if m[n, i, j] == 0:
+ for k in range(3, scale - 2):
+ y[n, :, i * scale + k, j * scale + k] = 0
+ y[n, :, i * scale + k, j * scale + scale - k] = 0
+
+ y = y[:, :, 1:, 1:]
return y
- def save_image(
+ def add_frame(self, img, colors, thickness):
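+ # pad img with a constant-color border of the given thickness; colors is
+ # broadcast over the batch, so it can be a single color or one per image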
+ result = img.new(
+ img.size(0),
+ img.size(1),
+ img.size(2) + 2 * thickness,
+ img.size(3) + 2 * thickness,
+ )
+
+ result[...] = colors[:, :, None, None]
+ result[:, :, thickness:-thickness, thickness:-thickness] = img
+
+ return result
+
+ def save_quizzes_as_image(
self,
result_dir,
filename,
- prompts,
- answers,
- predicted_prompts=None,
- predicted_answers=None,
+ quizzes,
+ predicted_parts=None,
+ correct_parts=None,
nrow=4,
margin=8,
):
S = self.height * self.width
- As = prompts[:, 0 * (S + 1) + 1 : 0 * (S + 1) + S + 1].view(
- -1, self.height, self.width
- )
- f_As = prompts[:, 1 * (S + 1) + 1 : 1 * (S + 1) + S + 1].view(
- -1, self.height, self.width
- )
- Bs = prompts[:, 2 * (S + 1) + 1 : 2 * (S + 1) + S + 1].view(
- -1, self.height, self.width
+
+ A, f_A, B, f_B = (
+ quizzes.reshape(-1, 4, S + 1)[:, :, 1:]
+ .reshape(-1, 4, self.height, self.width)
+ .permute(1, 0, 2, 3)
)
- prompts = torch.cat([As, f_As, Bs], dim=2)
- answers = answers[:, 1 : S + 1].reshape(
- answers.size(0), self.height, self.width
+
+ black, white, gray, green, red = torch.tensor(
+ [[0, 0, 0], [255, 255, 255], [200, 200, 200], [0, 255, 0], [255, 0, 0]],
+ device=quizzes.device,
)
- if predicted_prompts is None:
- predicted_prompts = 255
+ img_A = self.add_frame(self.grid2img(A), black[None, :], thickness=1)
+ img_f_A = self.add_frame(self.grid2img(f_A), black[None, :], thickness=1)
+ img_B = self.add_frame(self.grid2img(B), black[None, :], thickness=1)
+ img_f_B = self.add_frame(self.grid2img(f_B), black[None, :], thickness=1)
- if predicted_answers is None:
- predicted_answers = 255
+ # predicted_parts Nx4
+ # correct_parts Nx4
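+ # frame colors: white = part not predicted, gray = predicted with no
+ # correctness information, green / red = predicted and correct / incorrect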
- def add_frame(x, c, margin, bottom=False):
- if bottom:
- h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
- else:
- h, w, di, dj = (
- x.size(2) + 2 * margin,
- x.size(3) + 2 * margin,
- margin,
- margin,
+ if predicted_parts is None:
+ colors = white[None, None, :].expand(-1, 4, -1)
+ else:
+ if correct_parts is None:
+ colors = (
+ predicted_parts[:, :, None] * gray[None, None, :]
+ + (1 - predicted_parts[:, :, None]) * white[None, None, :]
)
-
- y = x.new_full((x.size(0), x.size(1), h, w), 0)
-
- if type(c) is int:
- y[...] = c
else:
- c = c.long()[:, None]
- c = (
- (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
- * torch.tensor([64, 64, 64])
- + (c == 1).long() * torch.tensor([0, 255, 0])
- + (c == 0).long() * torch.tensor([255, 255, 255])
- + (c == -1).long() * torch.tensor([255, 0, 0])
- )
- y[...] = c[:, :, None, None]
-
- y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
-
- return y
-
- img_prompts = torch.cat(
- [
- add_frame(
- add_frame(self.frame2img(x), c=0, margin=1),
- c=predicted_prompts,
- margin=margin,
+ colors = (
+ predicted_parts[:, :, None]
+ * (
+ correct_parts[:, :, None] * green[None, None, :]
+ + (1 - correct_parts[:, :, None]) * red[None, None, :]
+ )
+ + (1 - predicted_parts[:, :, None]) * white[None, None, :]
)
- for x in prompts.to("cpu").split(split_size=self.width, dim=2)
- ],
- dim=3,
- )
-
- h = img_prompts.size(2)
- img_answers = add_frame(
- add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
- c=predicted_answers,
- margin=margin,
- )
-
- separator_size = 2 * margin
-
- separator = img_prompts.new_full(
- (
- img_prompts.size(0),
- img_prompts.size(1),
- img_prompts.size(2),
- separator_size,
- ),
- 255,
- )
-
- marker = img_prompts.new_full(
- (
- img_prompts.size(0),
- img_prompts.size(1),
- img_prompts.size(2),
- separator_size,
- ),
- 255,
- )
- # marker[:, :, 0] = 0
- # marker[:, :, h - 1] = 0
+ img_A = self.add_frame(img_A, colors[:, 0], thickness=6)
+ img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=6)
+ img_B = self.add_frame(img_B, colors[:, 2], thickness=6)
+ img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=6)
- for k in range(1, 2 * separator_size - 8):
- i = k - (separator_size - 4)
- j = separator_size - 5 - abs(i)
- marker[:, :, h // 2 - 1 + i, 2 + j] = 0
- marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+ img_A = self.add_frame(img_A, white[None, :], thickness=2)
+ img_f_A = self.add_frame(img_f_A, white[None, :], thickness=2)
+ img_B = self.add_frame(img_B, white[None, :], thickness=2)
+ img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2)
- img = torch.cat(
- [
- img_prompts,
- marker,
- img_answers,
- ],
- dim=3,
- )
+ img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3)
image_name = os.path.join(result_dir, filename)
+
torchvision.utils.save_image(
img.float() / 255.0,
image_name,
######################################################################
- def nb_token_values(self):
- return len(self.colors) + 2
-
# @torch.compile
def rec_coo(
self,
f_Bs = answers[:, 1:]
return (Bs == f_Bs).long().min(dim=-1).values > 0
- def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
+ def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
if tasks is None:
tasks = self.all_tasks
S = self.height * self.width
- prompts = torch.full((nb, 3 * S + 3), self.token_forward)
- answers = torch.full((nb, S + 1), self.token_forward)
-
- bunch = zip(prompts, answers)
+ quizzes = torch.empty(nb, 4 * (S + 1), dtype=torch.int64)
+
+ # write the structure tokens that prefix the A, f_A, B and f_B grids
+ quizzes[:, 0 * (S + 1)] = self.l2tok["A"]
+ quizzes[:, 1 * (S + 1)] = self.l2tok["f_A"]
+ quizzes[:, 2 * (S + 1)] = self.l2tok["B"]
+ quizzes[:, 3 * (S + 1)] = self.l2tok["f_B"]
+
+ bunch = quizzes
if progress_bar:
- bunch = tqdm.tqdm(
- bunch,
+ bunch = tqdm.tqdm(
+ quizzes,
dynamic_ncols=True,
- desc="world generation",
+ desc="world quizzes generation",
- total=prompts.size(0),
+ total=quizzes.size(0),
)
- for prompt, answer in bunch:
- A = prompt[0 * (S + 1) + 1 : 0 * (S + 1) + 1 + S].view(
- self.height, self.width
- )
- f_A = prompt[1 * (S + 1) + 1 : 1 * (S + 1) + 1 + S].view(
- self.height, self.width
- )
- B = prompt[2 * (S + 1) + 1 : 2 * (S + 1) + S + 1].view(
- self.height, self.width
- )
- f_B = answer[1 : S + 1].view(self.height, self.width)
+ for quiz in bunch:
+ q = quiz.reshape(4, S + 1)[:, 1:].reshape(4, self.height, self.width)
+ q[...] = 0
+ A, f_A, B, f_B = q
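+ # the selected task_* generator fills the four grid views in place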
task = tasks[torch.randint(len(tasks), (1,)).item()]
- A[...] = 0
- f_A[...] = 0
- B[...] = 0
- f_B[...] = 0
task(A, f_A, B, f_B)
- return prompts.flatten(1), answers.flatten(1)
-
- def save_quiz_illustrations(
- self,
- result_dir,
- filename_prefix,
- prompts,
- answers,
- predicted_prompts=None,
- predicted_answers=None,
- nrow=4,
- ):
- self.save_image(
- result_dir,
- filename_prefix + ".png",
- prompts,
- answers,
- predicted_prompts,
- predicted_answers,
- nrow,
- )
+ return quizzes
def save_some_examples(self, result_dir):
nb, nrow = 128, 4
for t in self.all_tasks:
print(t.__name__)
- prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
- self.save_quiz_illustrations(
+ quizzes = self.generate_w_quizzes_(nb, tasks=[t])
+ self.save_quizzes_as_image(
- result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+ result_dir, t.__name__ + ".png", quizzes[:nb], nrow=nrow
)
# )
# time.sleep(10)
# start_time = time.perf_counter()
- # prompts, answers = grids.generate_prompts_and_answers(nb)
+ # prompts, answers = grids.generate_w_quizzes(nb)
# delay = time.perf_counter() - start_time
# print(f"{prompts.size(0)/delay:02f} seq/s")
# exit(0)
# nb, nrow = 8, 2
# for t in grids.all_tasks:
- for t in [grids.task_reconfigure]:
+ for t in [grids.task_replace_color]:
# for t in [grids.task_symbols]:
print(t.__name__)
- prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
- # prompts[...] = torch.randint(grids.nb_token_values(), prompts.size())
- grids.save_quiz_illustrations(
- "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+ quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
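+ # illustration only: mark just the f_B part as predicted, with random
+ # correctness flags, to exercise the green/red framing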
+ predicted_parts = quizzes.new_zeros(quizzes.size(0), 4)
+ predicted_parts[:, 3] = 1
+ correct_parts = torch.randint(2, (quizzes.size(0), 4), device=quizzes.device)
+ grids.save_quizzes_as_image(
+ "/tmp",
+ t.__name__ + ".png",
+ quizzes,
+ predicted_parts=predicted_parts,
+ correct_parts=correct_parts,
)
exit(0)
for t in grids.all_tasks:
# for t in [grids.task_compute]:
start_time = time.perf_counter()
- prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
+ quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
delay = time.perf_counter() - start_time
print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s")
return c_quizzes.to("cpu")
######################################################################
-
- def generate_c_quizzes_mixing(
- self,
- nb,
- model_for_generation,
- p2a_only=False,
- temperature_hot=1.0,
- temperature_cold=1.0,
- ):
- c_quizzes = torch.empty(
- nb,
- self.prompt_len + self.answer_len,
- device=self.device,
- dtype=torch.int64,
- )
-
- c_quizzes_1 = torch.empty(
- nb,
- self.prompt_len + self.answer_len,
- device=self.device,
- dtype=torch.int64,
- )
-
- c_quizzes_2 = torch.empty(
- nb,
- self.prompt_len + self.answer_len,
- device=self.device,
- dtype=torch.int64,
- )
-
- seq_logproba = torch.zeros(nb, device=self.device)
-
- lt_noisy = lambda s, logits: logits / temperature_hot
- lt_clean = lambda s, logits: logits / temperature_cold
-
- ######################################################################
-
- c_quizzes_1[...] = self.problem.token_backward
- ar_mask = self.problem.make_ar_mask(c_quizzes_1, shape="fwd_012_bck_0")
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes_1,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba,
- logit_transformer=lt_noisy,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- self.save_quiz_illustrations("/tmp", f"c_quizzes_1", c_quizzes_1)
-
- c_quizzes_2[...] = self.problem.token_backward
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes_2,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba,
- logit_transformer=lt_noisy,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- self.save_quiz_illustrations("/tmp", f"c_quizzes_2", c_quizzes_2)
-
- h = len(model_for_generation.trunk) // 2
-
- with torch.autograd.no_grad():
- t = model_for_generation.training
- model_for_generation.eval()
-
- bs1 = model_for_generation.partial_forward(
- mygpt.BracketedSequence(c_quizzes_1), end_layer=h
- )
- bs2 = model_for_generation.partial_forward(
- mygpt.BracketedSequence(c_quizzes_2), end_layer=h
- )
-
- alpha = 0.5
-
- output = model_for_generation.partial_forward(
- mygpt.BracketedSequence(alpha * bs1.x + (1 - alpha) * bs2.x),
- start_layer=h,
- ).x
-
- dist = torch.distributions.categorical.Categorical(logits=output)
- c_quizzes[...] = dist.sample()
-
- c_quizzes[...] = (
- ar_mask * c_quizzes + (1 - ar_mask) * self.problem.token_backward
- )
-
- model_for_generation.train(t)
-
- self.save_quiz_illustrations("/tmp", f"c_quizzes", c_quizzes)
-
- ######################################################################
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
- seq_logproba=seq_logproba,
- logit_transformer=lt_clean,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- self.save_quiz_illustrations("/tmp", f"c_quizzes_A", c_quizzes)
-
- c_quizzes = self.problem.p_a_flip(c_quizzes)
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
- seq_logproba=seq_logproba,
- logit_transformer=lt_clean,
- deterministic_synthesis=False,
- device=self.device,
- )
-
- self.save_quiz_illustrations("/tmp", f"c_quizzes_B", c_quizzes)
-
- print("DONE")
- exit(0)
-
- return c_quizzes.to("cpu")