break
return i1, j1, i2, j2
+ ######################################################################
+
def task_replace_color(self, A, f_A, B, f_B):
c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1
- for n, X, f_X in [(1, A, f_A), (3, B, f_B)]:
+ for n, X, f_X in [(1, A, f_A), (1, B, f_B)]:
for _ in range(torch.randint(n, (1,)) + 1):
i1, j1, i2, j2 = self.rec_coo(X)
X[i1:i2, j1:j2] = c1
f_X[i1:i2, j1:j2] = c2
def task_move(self, A, f_A, B, f_B):
- c = torch.randperm(len(self.colors) - 1)[:1] + 1
di, dj = torch.randint(2, (2,)) * 2 - 1
- for n, X, f_X in [(1, A, f_A), (3, B, f_B)]:
+ for n, X, f_X in [(1, A, f_A), (1, B, f_B)]:
+ c = torch.randperm(len(self.colors) - 1)[:1] + 1
for _ in range(torch.randint(n, (1,)) + 1):
while True:
i1, j1, i2, j2 = self.rec_coo(X)
f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c
def task_grow(self, A, f_A, B, f_B):
- c = torch.randperm(len(self.colors) - 1)[:1] + 1
-
- for n, X, f_X in [(1, A, f_A), (3, B, f_B)]:
+ direction = torch.randint(2, (1,))
+ for n, X, f_X in [(1, A, f_A), (1, B, f_B)]:
+ c = torch.randperm(len(self.colors) - 1)[:1] + 1
for _ in range(torch.randint(n, (1,)) + 1):
while True:
i1, j1, i2, j2 = self.rec_coo(X)
if i1 + 3 < i2 and j1 + 3 < j2:
break
- X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c
+ if direction == 0:
+ X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c
+ f_X[i1:i2, j1:j2] = c
+ else:
+ X[i1:i2, j1:j2] = c
+ f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c
+
+ def task_frame(self, A, f_A, B, f_B):
+ direction = torch.randint(2, (1,))
+ for n, X, f_X in [(1, A, f_A), (1, B, f_B)]:
+ c = torch.randperm(len(self.colors) - 1)[:1] + 1
+ for _ in range(torch.randint(n, (1,)) + 1):
+ while True:
+ i1, j1, i2, j2 = self.rec_coo(X)
+ if i1 + 3 < i2 and j1 + 3 < j2:
+ break
+ X[i1:i2, j1:j2] = c
f_X[i1:i2, j1:j2] = c
+ f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
+
+ ######################################################################
def generate_prompts_and_answers(self, nb):
- tasks = [self.task_replace_color, self.task_move, self.task_grow]
+ tasks = [
+ self.task_replace_color,
+ self.task_move,
+ self.task_grow,
+ self.task_frame,
+ ]
prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
w = self.width
prompts, answers = lang.generate_prompts_and_answers(36)
- predicted_prompts = torch.rand(prompts.size(0)) < 0.5
- predicted_answers = torch.logical_not(predicted_prompts)
+ # predicted_prompts = torch.rand(prompts.size(0)) < 0.5
+ # predicted_answers = torch.logical_not(predicted_prompts)
lang.save_quizzes(
"/tmp",
"test",
prompts,
- answers, # predicted_prompts, predicted_answers
+ answers,
+ # You can add a bool to put a frame around the predicted parts
+ # predicted_prompts, predicted_answers
)
-
- # start_time = time.perf_counter()
- # token_sequences = lang.generate_token_sequences(nb=64)
- # delay = time.perf_counter() - start_time
- # print(f"{token_sequences.size(0)/delay:02f} seq/s")
-
- # print(lang.seq2str(seq[:4]))
-
- # for t in range(len(it[0])):
- # img = torch.cat([lang.frame2img(f[t]) for f in it], dim=0)
- # torchvision.utils.save_image(
- # img.float() / 255.0,
- # f"/tmp/frame_{t:03d}.png",
- # nrow=8,
- # padding=6,
- # pad_value=0,
- # )
-
- # m = (torch.rand(seq.size()) < 0.05).long()
- # seq = (1 - m) * seq + m * 23
-
- # print(seq.size())
- # img = lang.seq2img(token_sequences)
- # print(img.size())
-
- # torchvision.utils.save_image(
- # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
- # )