max_nb_cached_chunks=None,
chunk_size=None,
nb_threads=-1,
- world_tasks=None,
- science_tasks=None,
+ tasks=None,
):
self.colors = torch.tensor([c for _, c in self.named_colors])
self.cache_rec_coo = {}
- self.all_tasks = [
+ all_tasks = [
self.task_replace_color,
self.task_translate,
self.task_grow,
# self.task_islands, # TOO MESSY
]
- if world_tasks is None:
- self.world_tasks = self.all_tasks
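+ # `tasks` is a comma-separated list of task-name suffixes (e.g. "replace_color,translate");
+ # None selects every task in all_tasks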
+ if tasks is None:
+ self.all_tasks = all_tasks
else:
- self.world_tasks = [
- getattr(self, "task_" + t) for t in world_tasks.split(",")
- ]
-
- if science_tasks is not None:
- self.science_tasks = [
- getattr(self, "task_" + t) for t in science_tasks.split(",")
- ]
+ self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")]
super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
rs = (S // (2**j)) % 2
f_X[2, -j - 1] = c[2 + rs]
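+ # "Implicit rectangle" task: the target grid f_X contains a filled rectangle of
+ # color c[0]; four bars of colors c[1]..c[4] abut its four sides and are drawn in
+ # both X and f_X, so the rectangle has to be inferred from the bars alone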
+ def task_science_implicit(self, A, f_A, B, f_B):
+ nb_rec = 5
+ c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+
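+ # pick the rectangle's row extent [i1, i2) and column extent [j1, j2):
+ # at least one cell of margin on each side and at least 4 cells across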
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ while True:
+ i1, i2 = torch.randint(self.height, (2,)).sort().values
+ if i1 >= 1 and i2 < self.height and i1 + 3 < i2:
+ break
+
+ while True:
+ j1, j2 = torch.randint(self.width, (2,)).sort().values
+ if j1 >= 1 and j2 < self.width and j1 + 3 < j2:
+ break
+
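+ # the rectangle itself appears only in the target grid f_X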
+ f_X[i1:i2, j1:j2] = c[0]
+
+ # ---------------------
+
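+ # bars touching the left and right sides of the rectangle, drawn in both X and f_X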
+ while True:
+ ii1, ii2 = torch.randint(self.height, (2,)).sort().values
+ if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2:
+ break
+ jj = torch.randint(j1, (1,))
+ X[ii1:ii2, jj:j1] = c[1]
+ f_X[ii1:ii2, jj:j1] = c[1]
+
+ while True:
+ ii1, ii2 = torch.randint(self.height, (2,)).sort().values
+ if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2:
+ break
+ jj = torch.randint(self.width - j2, (1,)) + j2 + 1
+ X[ii1:ii2, j2:jj] = c[2]
+ f_X[ii1:ii2, j2:jj] = c[2]
+
+ # ---------------------
+
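+ # bars touching the top and bottom sides of the rectangle, drawn in both X and f_X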
+ while True:
+ jj1, jj2 = torch.randint(self.width, (2,)).sort().values
+ if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2:
+ break
+ ii = torch.randint(i1, (1,))
+ X[ii:i1, jj1:jj2] = c[3]
+ f_X[ii:i1, jj1:jj2] = c[3]
+
+ while True:
+ jj1, jj2 = torch.randint(self.width, (2,)).sort().values
+ if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2:
+ break
+ ii = torch.randint(self.height - i2, (1,)) + i2 + 1
+ X[i2:ii, jj1:jj2] = c[4]
+ f_X[i2:ii, jj1:jj2] = c[4]
+
# end_tasks
######################################################################
return quizzes
- def generate_w_quizzes_(self, nb, tasks=None, science=False, progress_bar=False):
+ def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
S = self.height * self.width
if tasks is None:
- if science:
- tasks = self.science_tasks
- else:
- tasks = self.world_tasks
+ tasks = self.all_tasks
quizzes = self.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B"))
return quizzes
- def save_some_examples(self, result_dir, science=False):
+ def save_some_examples(self, result_dir):
nb, nrow = 128, 4
- tasks = self.science_tasks if science else self.world_tasks
- for t in tasks:
+ for t in self.all_tasks:
print(t.__name__)
quizzes = self.generate_w_quizzes_(nb, tasks=[t])
self.save_quizzes_as_image(
nb, nrow = 128, 4
# nb, nrow = 8, 2
- # for t in grids.world_tasks:
+ # for t in grids.all_tasks:
- for t in [grids.task_path]:
+ for t in [grids.task_science_implicit]:
print(t.__name__)
quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
grids.save_quizzes_as_image(
comments=[f"{t.__name__} #{k}" for k in range(quizzes.size(0))],
)
- # exit(0)
+ exit(0)
nb = 1000
help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
)
-assert (
- len(
- set(args.grids_world_tasks.split(","))
- & set(args.grids_science_tasks.split(","))
- )
- == 0
-), "World and science task have to be disjoint"
-
######################################################################
parser.add_argument("--sky_height", type=int, default=6)
if args.result_dir is None:
args.result_dir = f"results_culture"
+assert not args.grids_science_tasks or (
+ len(
+ set(args.grids_world_tasks.split(","))
+ & set(args.grids_science_tasks.split(","))
+ )
+ == 0
+), "World and science tasks have to be disjoint"
+
######################################################################
default_args = {
chunk_size=100,
nb_threads=args.nb_threads,
)
- back_accuracy = False
elif args.problem == "grids":
problem = grids.Grids(
max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
chunk_size=100,
nb_threads=args.nb_threads,
- world_tasks=args.grids_world_tasks,
- science_tasks=args.grids_science_tasks,
+ tasks=args.grids_world_tasks,
)
- back_accuracy = True
+
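+ # build a separate Grids problem restricted to the science tasks and generate
+ # a fixed set of held-out test quizzes from it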
+ if args.grids_science_tasks is None:
+ science_w_quizzes = None
+ else:
+ science_problem = grids.Grids(
+ max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
+ chunk_size=100,
+ nb_threads=args.nb_threads,
+ tasks=args.grids_science_tasks,
+ )
+ science_w_quizzes = science_problem.generate_w_quizzes(args.nb_test_samples)
else:
raise ValueError
quiz_machine = quiz_machine.QuizMachine(
problem=problem,
- back_accuracy=back_accuracy,
batch_size=args.physical_batch_size,
result_dir=args.result_dir,
logger=log_string,
######################################################################
-science_w_quizzes = quiz_machine.problem.generate_w_quizzes(
- args.nb_test_samples, science=True
-)
-
-######################################################################
-
current_epoch = 0
if args.resume:
c_quizzes,
)
+ ######################################################################
+
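+ # evaluate the model on the held-out science quizzes: predict the f_B part
+ # from (A, f_A, B) and log the resulting accuracy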
+ if science_w_quizzes is not None:
+ result, correct = quiz_machine.predict(
+ model=model,
+ quizzes=science_w_quizzes.to(main_device),
+ struct=("A", "f_A", "B", "f_B"),
+ mask=(0, 0, 0, 1),
+ )
+
+ nb_correct = (correct == 1).long().sum()
+ nb_total = (correct != 0).long().sum()
+ log_string(
+ f"science_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}"
+ )
+
+ ######################################################################
+
# Renew the training samples
for model in weakest_models: