def nb_token_values(self):
return len(self.colors)
+ # That's quite a tensorial spaghetti mess to sample
+ # non-overlapping rectangles quickly, but made the generation of
+ # 100k samples from 1h50 with a lame pure python code to 4min with
+ # this one.
def rec_coo(self, x, n, min_height=3, min_width=3):
K = 3
- N = 4000
+ N = 1000
while True:
v = (
self.task_frame,
self.task_detect,
]
- prompts = torch.zeros(
- nb, self.height, self.width * 3, dtype=torch.int64, device=self.device
- )
- answers = torch.zeros(
- nb, self.height, self.width, dtype=torch.int64, device=self.device
- )
+ prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
+ answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
w = self.width
for prompt, answer in tqdm.tqdm(
f_B = answer
task = tasks[torch.randint(len(tasks), (1,))]
task(A, f_A, B, f_B)
+
return prompts.flatten(1), answers.flatten(1)
def save_quizzes(