class Grids(problem.Problem):
- grid_gray = 64
- thickness = 1
- background_gray = 255
+ # grid_gray = 64
+ # thickness = 1
+ # background_gray = 255
+ # dots = False
# grid_gray=240
# thickness=1
# background_gray=240
+ # dots = False
- # grid_gray = 255
- # thickness = 0
- # background_gray = 240
+ grid_gray = 200
+ thickness = 0
+ background_gray = 240
+ dots = True
named_colors = [
("white", [background_gray, background_gray, background_gray]),
("gray", [128, 128, 128]),
]
- def check_order(self, quizzes, quad_order):
- S = self.height * self.width
-
- return (
- (quizzes[:, 0 * (S + 1)] == self.l2tok[quad_order[0]])
- & (quizzes[:, 1 * (S + 1)] == self.l2tok[quad_order[1]])
- & (quizzes[:, 2 * (S + 1)] == self.l2tok[quad_order[2]])
- & (quizzes[:, 3 * (S + 1)] == self.l2tok[quad_order[3]])
- ).all()
-
- def get_order(self, quizzes):
- S = self.height * self.width
- quad_order = tuple(
- self.tok2l[n.item()]
- for n in quizzes.reshape(quizzes.size(0), 4, S + 1)[0, :, 0]
- )
- self.check_order(quizzes, quad_order)
- return quad_order
-
- def inject_noise(self, quizzes, noise, quad_order, quad_noise):
- assert self.check_order(quizzes, quad_order=quad_order)
- S = self.height * self.width
-
- mask = torch.tensor(quad_noise, device=quizzes.device)
- mask = mask[None, :, None].expand(1, 4, S + 1).clone()
- mask[:, :, 0] = 0
- mask = mask.reshape(1, -1).expand_as(quizzes)
- mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long()
- random = torch.randint(self.nb_colors, mask.size())
- quizzes = mask * random + (1 - mask) * quizzes
-
- return quizzes
-
# Return an (nb, 4 * height * width) integer tensor on `device` whose entries
# are drawn uniformly from [0, nb_colors): one row of random color tokens per
# sample, covering the four height x width grids of a quiz.
def pure_noise(self, nb, device):
result = torch.randint(
- self.nb_colors, (nb, 4 * (self.height * self.height + 1)), device=device
# NOTE(review): size with height * width (not height * height) so the row
# length equals seq_len = 4 * height * width even for non-square grids.
+ self.nb_colors, (nb, 4 * (self.height * self.width)), device=device
)
- result.view(nb, 4, -1)[:, 0, 0] = self.token_A
- result.view(nb, 4, -1)[:, 1, 0] = self.token_f_A
- result.view(nb, 4, -1)[:, 2, 0] = self.token_B
- result.view(nb, 4, -1)[:, 3, 0] = self.token_f_B
- return result
-
- # What a mess
- def reconfigure(self, quizzes, quad_order=("A", "f_A", "B", "f_B")):
- if torch.is_tensor(quizzes):
- return self.reconfigure([quizzes], quad_order=quad_order)[0]
-
- S = self.height * self.width
- result = [x.new(x.size()) for x in quizzes]
-
- quad_order_from = self.get_order(quizzes[0][:1])
- i = self.indices_select(quizzes[0], quad_order_from)
-
- sf = dict((l, n) for n, l in enumerate(quad_order_from))
-
- for q in range(4):
- k = sf[quad_order[q]]
- for x, y in zip(quizzes, result):
- l = x.size(1) // 4
- y[i, q * l : (q + 1) * l] = x[i, k * l : (k + 1) * l]
-
- j = i == False
-
- if j.any():
- for z, y in zip(
- self.reconfigure([x[j] for x in quizzes], quad_order=quad_order), result
- ):
- y[j] = z
-
return result
def trivial(self, quizzes):
dim=1
).values
- def make_quiz_mask(
- self, quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=(0, 0, 0, 1)
- ):
- assert self.check_order(quizzes, quad_order)
-
- ar_mask = quizzes.new_zeros(quizzes.size())
-
- S = self.height * self.width
- a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:]
- a[:, 0, :] = quad_mask[0]
- a[:, 1, :] = quad_mask[1]
- a[:, 2, :] = quad_mask[2]
- a[:, 3, :] = quad_mask[3]
-
- return ar_mask
-
def text2quiz(self, t):
chr2col = [
(".", "white"),
self.colors = torch.tensor([c for _, c in self.named_colors])
self.nb_colors = len(self.colors)
- self.token_A = self.nb_colors
- self.token_f_A = self.token_A + 1
- self.token_B = self.token_f_A + 1
- self.token_f_B = self.token_B + 1
self.nb_rec_max = 5
self.rfree = torch.tensor([])
- self.l2tok = {
- "A": self.token_A,
- "f_A": self.token_f_A,
- "B": self.token_B,
- "f_B": self.token_f_B,
- }
-
- self.tok2l = {
- self.token_A: "A",
- self.token_f_A: "f_A",
- self.token_B: "B",
- self.token_f_B: "f_B",
- }
-
self.height = 10
self.width = 10
- self.seq_len = 4 * (1 + self.height * self.width)
- self.nb_token_values = self.token_f_B + 1
+ self.seq_len = 4 * self.height * self.width
self.cache_rec_coo = {}
######################################################################
# NOTE(review): returns the number of distinct token values. With this change
# it is exactly nb_colors (the length of self.colors), because the four marker
# tokens token_A / token_f_A / token_B / token_f_B that used to extend the
# vocabulary (nb_token_values = token_f_B + 1) are removed by this patch.
def vocabulary_size(self):
- return self.nb_token_values
+ return self.nb_colors
def grid2img(self, x, scale=15, grids=True):
m = torch.logical_and(x >= 0, x < self.nb_colors).long()
for t in range(self.thickness):
y[:, :, :, torch.arange(t, y.size(3), scale)] = self.grid_gray
y[:, :, torch.arange(t, y.size(2), scale), :] = self.grid_gray
+ if self.dots:
+ z = y.reshape(
+ y.size(0),
+ y.size(1),
+ y.size(2) // scale,
+ scale,
+ y.size(3) // scale,
+ scale,
+ )
+ z = z[
+ :,
+ :,
+ :,
+ scale // 2 - 2 : scale // 2 + 1,
+ :,
+ scale // 2 - 2 : scale // 2 + 1,
+ ]
+ z[...] = (z == self.background_gray) * self.grid_gray + (
+ z != self.background_gray
+ ) * z
for n in range(m.size(0)):
for i in range(m.size(1)):
for j in range(m.size(2)):
- if m[n, i, j] == 0:
- for k in range(3, scale - 2):
- y[n, :, i * scale + k, j * scale + k] = 0
- y[n, :, i * scale + k, j * scale + scale - k] = 0
+ if x[n, i, j] >= self.nb_colors:
+ # for k in range(3, scale - 2):
+ c = self.colors[x[n, i, j] - self.nb_colors][:, None, None]
+ # y[n, :, i * scale + k, j * scale + k] = c
+ # y[n, :, i * scale + k, j * scale + scale - k] = c
+ y[
+ n,
+ :,
+ i * scale + 3 : i * scale + scale - 2,
+ j * scale + 3 : j * scale + scale - 2,
+ ] = c
y = y[:, :, 1:, 1:]
):
quizzes = quizzes.to("cpu")
- if quizzes.size(1) == 4 * self.height * self.width:
- quizzes = torch.cat(
- [
- quizzes.new_zeros(quizzes.size(0), 4, 1),
- quizzes.reshape(quizzes.size(0), 4, -1),
- ],
- dim=2,
- )
- quizzes[:, :, 0] = torch.tensor(
- [self.token_A, self.token_f_A, self.token_B, self.token_f_B]
- )[None, :]
- quizzes = quizzes.reshape(quizzes.size(0), -1)
-
- to_reconfigure = [quizzes]
- if predicted_parts is not None:
- to_reconfigure.append(predicted_parts)
- if correct_parts is not None:
- to_reconfigure.append(correct_parts)
-
- to_reconfigure = self.reconfigure(to_reconfigure, ("A", "f_A", "B", "f_B"))
-
- quizzes = to_reconfigure.pop(0)
- if predicted_parts is not None:
- predicted_parts = to_reconfigure.pop(0)
- if correct_parts is not None:
- correct_parts = to_reconfigure.pop(0)
-
S = self.height * self.width
A, f_A, B, f_B = (
- quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
+ quizzes.reshape(quizzes.size(0), 4, S)
.reshape(quizzes.size(0), 4, self.height, self.width)
.permute(1, 0, 2, 3)
)
if tasks is None:
tasks = self.all_tasks
- quizzes = self.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B"))
+ quizzes = torch.empty(nb, 4 * self.height * self.width, dtype=torch.int64)
if progress_bar:
quizzes = tqdm.tqdm(
)
for quiz in quizzes:
- q = quiz.reshape(4, S + 1)[:, 1:].reshape(4, self.height, self.width)
+ q = quiz.reshape(4, self.height, self.width)
q[...] = 0
A, f_A, B, f_B = q
task = tasks[torch.randint(len(tasks), (1,)).item()]
]:
print(t.__name__)
w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
+
+ w_quizzes[:5] = torch.randint(grids.vocabulary_size(), w_quizzes[:5].size())
+
grids.save_quizzes_as_image(
"/tmp",
t.__name__ + ".png",
######################################################################
-def pure_noise(nb, device):
- r = problem.pure_noise(nb, device)
- r = r.view(r.size(0), 4, -1)[:, :, 1:].reshape(r.size(0), -1)
- return r
-
-
def quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
if c_quizzes is None:
quizzes = problem.generate_w_quizzes(nb_samples)
- quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].reshape(
- quizzes.size(0), -1
- )
nb_w_quizzes = quizzes.size(0)
nb_c_quizzes = 0
else:
"""Replace every component of the input by a random value with
probability args.proba_prompt_noise."""
input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
- noise = pure_noise(input.size(0), input.device)
+ noise = problem.pure_noise(input.size(0), input.device)
change = (1 - masks) * (
torch.rand(input.size(), device=input.device) < args.proba_prompt_noise
).long()
proba_erased = 1 - (1 - args.diffusion_proba_corruption) ** t
mask_erased = (r <= proba_erased[:, None]).long()
- noise = pure_noise(nb, input.device)
+ noise = problem.pure_noise(nb, input.device)
targets = input
input = (1 - mask_erased) * input + mask_erased * noise
masks = input.new_full(input.size(), 1)
# mini-batches second so that we keep only the samples that have
# not stabilized
- all_input = pure_noise(nb, local_device)
+ all_input = problem.pure_noise(nb, local_device)
all_masks = all_input.new_full(all_input.size(), 1)
all_changed = torch.full((all_input.size(0),), True, device=all_input.device)