self.check_structure(quizzes, struct)
return struct
+ def inject_noise(self, quizzes, noise, struct, mask):
+ assert self.check_structure(quizzes, struct=struct)
+ S = self.height * self.width
+ mask = torch.tensor(mask, device=quizzes.device)
+ mask = mask[None, :, None].expand(1, 4, S + 1)
+ mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long()
+ mask = mask.reshape(1, -1).expand_as(quizzes)
+ random = torch.randint(self.nb_colors, mask.size())
+
+ quizzes = mask * random + (1 - mask) * quizzes
+
+ return quizzes
+
# What a mess
def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")):
if torch.is_tensor(quizzes):
):
self.colors = torch.tensor([c for _, c in self.named_colors])
- self.token_A = len(self.colors)
+ self.nb_colors = len(self.colors)
+ self.token_A = self.nb_colors
self.token_f_A = self.token_A + 1
self.token_B = self.token_f_A + 1
self.token_f_B = self.token_B + 1
######################################################################
def grid2img(self, x, scale=15):
- m = torch.logical_and(x >= 0, x < len(self.colors)).long()
+ m = torch.logical_and(x >= 0, x < self.nb_colors).long()
y = self.colors[x * m].permute(0, 3, 1, 2)
s = y.shape
y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
# @torch.compile
def task_replace_color(self, A, f_A, B, f_B):
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
r = self.rec_coo(nb_rec, prevent_overlap=True)
for n in range(nb_rec):
break
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
r = self.rec_coo(nb_rec, prevent_overlap=True)
def task_grow(self, A, f_A, B, f_B):
di, dj = torch.randint(2, (2,)) * 2 - 1
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
direction = torch.randint(2, (1,)).item()
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
def task_half_fill(self, A, f_A, B, f_B):
di, dj = torch.randint(2, (2,)) * 2 - 1
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
+ c = torch.randperm(self.nb_colors - 1)[: 2 * nb_rec] + 1
direction = torch.randint(4, (1,)).item()
for X, f_X in [(A, f_A), (B, f_B)]:
r = self.rec_coo(nb_rec, prevent_overlap=True)
# @torch.compile
def task_frame(self, A, f_A, B, f_B):
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
r = self.rec_coo(nb_rec, prevent_overlap=True)
for n in range(nb_rec):
# @torch.compile
def task_detect(self, A, f_A, B, f_B):
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
r = self.rec_coo(nb_rec, prevent_overlap=True)
for n in range(nb_rec):
N = 3
c = torch.zeros(N + 2, dtype=torch.int64)
- c[1:] = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
+ c[1:] = torch.randperm(self.nb_colors - 1)[: N + 1] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
if not hasattr(self, "cache_count") or len(self.cache_count) == 0:
# @torch.compile
def task_trajectory(self, A, f_A, B, f_B):
- c = torch.randperm(len(self.colors) - 1)[:2] + 1
+ c = torch.randperm(self.nb_colors - 1)[:2] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
di, dj = torch.randint(7, (2,)) - 3
# @torch.compile
def task_bounce(self, A, f_A, B, f_B):
- c = torch.randperm(len(self.colors) - 1)[:3] + 1
+ c = torch.randperm(self.nb_colors - 1)[:3] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
# @torch.compile
def free(i, j):
# @torch.compile
def task_scale(self, A, f_A, B, f_B):
- c = torch.randperm(len(self.colors) - 1)[:2] + 1
+ c = torch.randperm(self.nb_colors - 1)[:2] + 1
i, j = (
torch.randint(self.height // 2, (1,)).item(),
# @torch.compile
def task_symbols(self, A, f_A, B, f_B):
nb_rec = 4
- c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
delta = 3
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
X[...] = 0
f_X[...] = 0
- c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
for r in range(nb_rec):
while True:
# @torch.compile
def REMOVED_task_distance(self, A, f_A, B, f_B):
- c = torch.randperm(len(self.colors) - 1)[:3] + 1
+ c = torch.randperm(self.nb_colors - 1)[:3] + 1
dist0 = torch.empty(self.height + 2, self.width + 2)
dist1 = torch.empty(self.height + 2, self.width + 2)
for X, f_X in [(A, f_A), (B, f_B)]:
def TOO_HARD_task_puzzle(self, A, f_A, B, f_B):
S = 4
i0, j0 = (self.height - S) // 2, (self.width - S) // 2
- c = torch.randperm(len(self.colors) - 1)[:4] + 1
+ c = torch.randperm(self.nb_colors - 1)[:4] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
f_X[...] = 0
X[ii + i, jj + j] = c[d]
def TOO_MESSY_task_islands(self, A, f_A, B, f_B):
- c = torch.randperm(len(self.colors) - 1)[:2] + 1
+ c = torch.randperm(self.nb_colors - 1)[:2] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
if not hasattr(self, "cache_islands") or len(self.cache_islands) == 0:
self.cache_islands = list(
# @torch.compile
def TOO_HARD_task_stack(self, A, f_A, B, f_B):
N = 5
- c = torch.randperm(len(self.colors) - 1)[:N] + 1
+ c = torch.randperm(self.nb_colors - 1)[:N] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
i1, j1, i2, j2 = (
self.height // 2 - 1,
def TOO_HARD_task_matrices(self, A, f_A, B, f_B):
N = 6
- c = torch.randperm(len(self.colors) - 1)[:N] + 1
+ c = torch.randperm(self.nb_colors - 1)[:N] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
M1 = torch.randint(2, (5, 5))
def TOO_HARD_task_compute(self, A, f_A, B, f_B):
N = 6
- c = torch.randperm(len(self.colors) - 1)[:N] + 1
+ c = torch.randperm(self.nb_colors - 1)[:N] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
v = torch.randint((self.width - 1) // 2, (N,)) + 1
chain = torch.randperm(N)
return min(max(v, 0) + max(h + 1, 0), max(v + 1, 0) + max(h, 0))
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
r = self.rec_coo(nb_rec, prevent_overlap=True)
def task_corners(self, A, f_A, B, f_B):
polarity = torch.randint(2, (1,)).item()
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
r = self.rec_coo(nb_rec, prevent_overlap=True)
# @torch.compile
def task_path(self, A, f_A, B, f_B):
nb_rec = 2
- c = torch.randperm(len(self.colors) - 1)[: nb_rec + 2] + 1
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 2] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
X[...] = 0
# @torch.compile
def task_fill(self, A, f_A, B, f_B):
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
accept_full = torch.rand(1) < 0.5
break
def TOO_HARD_task_addition(self, A, f_A, B, f_B):
- c = torch.randperm(len(self.colors) - 1)[:4] + 1
+ c = torch.randperm(self.nb_colors - 1)[:4] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item()
N2 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item()
def task_science_implicit(self, A, f_A, B, f_B):
nb_rec = 5
- c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+ c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
def task_science_dot(self, A, f_A, B, f_B):
nb_rec = 3
- c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
X[...] = 0
return False
def task_science_tag(self, A, f_A, B, f_B):
- c = torch.randperm(len(self.colors) - 1)[:4] + 1
+ c = torch.randperm(self.nb_colors - 1)[:4] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
rs = []
while len(rs) < 4:
problem,
batch_size,
result_dir,
+ prompt_noise,
logger,
device=torch.device("cpu"),
):
self.logger = logger
self.prompt_len = None
self.answer_len = None
+ self.prompt_noise = prompt_noise
self.understood_structures = [
(("A", "f_A", "B", "f_B"), (0, 0, 0, 1)),
if len(c_quizzes) > 0:
c_quizzes = torch.cat(c_quizzes, dim=0)
+
if c_quizzes.size(0) > w_quizzes.size(0) // 2:
i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
c_quizzes = c_quizzes[i]
quizzes = w_quizzes.clone()
from_w = torch.full((quizzes.size(0),), True, device=quizzes.device)
- self.randomize_configuations_inplace(
- quizzes, structs=[s for s, m in self.understood_structures]
+ i = torch.randperm(quizzes.size(0), device=quizzes.device)
+ quizzes, from_w = quizzes[i], from_w[i]
+
+ if self.prompt_noise > 0.0:
+ quizzes = self.problem.inject_noise(
+ quizzes, self.prompt_noise, ("A", "f_A", "B", "f_B"), (1, 0, 1, 0)
)
- i = torch.randperm(quizzes.size(0), device=quizzes.device)
+ self.randomize_configuations_inplace(
+ quizzes, structs=[s for s, m in self.understood_structures]
+ )
- return quizzes[i], from_w[i]
+ return quizzes, from_w
######################################################################
def renew_train_w_quizzes(self, model):
if hasattr(model, "hard_w_quizzes"):
- if model.hard_w_quizzes.size(0) >= model.train_w_quizzes.size(0):
+ hard_w_quizzes = self.problem.reconfigure(
+ model.hard_w_quizzes, struct=("A", "f_A", "B", "f_B")
+ )
+ self.logger(
+ f"re-using {hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
+ )
+ if hard_w_quizzes.size(0) >= model.train_w_quizzes.size(0):
nb_to_generate = 0
- model.train_w_quizzes[...] = model.hard_w_quizzes[
+ model.train_w_quizzes[...] = hard_w_quizzes[
torch.randperm(hard_w_quizzes.size(0))[
model.train_w_quizzes.size(0)
]
]
else:
- nb_to_generate = model.train_w_quizzes.size(
- 0
- ) - model.hard_w_quizzes.size(0)
+ nb_to_generate = model.train_w_quizzes.size(0) - hard_w_quizzes.size(0)
model.train_w_quizzes[...] = torch.cat(
[
- model.hard_w_quizzes,
+ hard_w_quizzes,
self.problem.generate_w_quizzes(nb_to_generate),
],
dim=0,
model.train_w_quizzes.size(0)
)
- self.logger(
- f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
- )
-
self.randomize_configuations_inplace(
model.train_w_quizzes, structs=[s for s, m in self.understood_structures]
)