max_nb_cached_chunks=None,
chunk_size=None,
nb_threads=-1,
+ tasks=None,
):
self.colors = torch.tensor([c for _, c in self.named_colors])
self.height = 10
self.width = 10
self.cache_rec_coo = {}
+
+ all_tasks = [
+ self.task_replace_color,
+ self.task_translate,
+ self.task_grow,
+ self.task_half_fill,
+ self.task_frame,
+ self.task_detect,
+ self.task_count,
+ self.task_trajectory,
+ self.task_bounce,
+ self.task_scale,
+ self.task_symbols,
+ self.task_isometry,
+ # self.task_path,
+ ]
+
+ if tasks is None:
+ self.all_tasks = all_tasks
+ else:
+ self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")]
+
super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
######################################################################
f_X[i1:i2, j1:j2] = c[n]
# @torch.compile
- def task_color_grow(self, A, f_A, B, f_B):
+ def task_half_fill(self, A, f_A, B, f_B):
di, dj = torch.randint(2, (2,)) * 2 - 1
nb_rec = 3
c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
# @torch.compile
- def task_ortho(self, A, f_A, B, f_B):
+ def task_isometry(self, A, f_A, B, f_B):
nb_rec = 3
di, dj = torch.randint(3, (2,)) - 1
o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
######################################################################
- def all_tasks(self):
- return [
- self.task_replace_color,
- self.task_translate,
- self.task_grow,
- self.task_color_grow,
- self.task_frame,
- self.task_detect,
- self.task_count,
- self.task_trajectory,
- self.task_bounce,
- self.task_scale,
- self.task_symbols,
- self.task_ortho,
- # self.task_path,
- ]
-
def trivial_prompts_and_answers(self, prompts, answers):
S = self.height * self.width
Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
if tasks is None:
- tasks = self.all_tasks()
+ tasks = self.all_tasks
S = self.height * self.width
prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
def save_some_examples(self, result_dir):
nb, nrow = 72, 4
- for t in self.all_tasks():
+ for t in self.all_tasks:
print(t.__name__)
prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
self.save_quizzes(
nb, nrow = 72, 4
# nb, nrow = 8, 2
- # for t in grids.all_tasks():
+ # for t in grids.all_tasks:
for t in [
grids.task_replace_color,
grids.task_frame,
nb = 1000
- for t in grids.all_tasks():
- # for t in [ grids.task_replace_color ]: #grids.all_tasks():
+ for t in grids.all_tasks:
+ # for t in [ grids.task_replace_color ]: #grids.all_tasks:
start_time = time.perf_counter()
prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
delay = time.perf_counter() - start_time
######################################################################
+grids_tasks = ", ".join(
+ [x.__name__.removeprefix("task_") for x in grids.Grids().all_tasks]
+)
+
+parser.add_argument(
+ "--grids_tasks",
+ type=str,
+ default=None,
+ help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
+)
+
+######################################################################
+
parser.add_argument("--sky_height", type=int, default=6)
parser.add_argument("--sky_width", type=int, default=8)
max_nb_cached_chunks=args.nb_gpus * args.nb_train_samples // 100,
chunk_size=100,
nb_threads=args.nb_threads,
+ tasks=args.grids_tasks,
)
back_accuracy = True
else: