("gray", [128, 128, 128]),
]
+ def make_ar_mask(self, quizzes, first=False):
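+ # Returns a 0/1 mask over the token positions of each quiz, 1 marking
+ # the positions the model must predict. E.g. with S = 4 a forward quiz
+ # reads [> A1..A4 > a1..a4 > B1..B4 > b1..b4]; first=False selects only
+ # the answer cells b1..b4, first=True the prompt cells of A, f(A), B.
+ # For backward quizzes the prompt is the first grid and the answer the
+ # remaining three. Direction tokens, at multiples of S + 1, are never
+ # selected.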
+ S = self.height * self.width
+
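+ # sanity check: the four segment headers must all carry the same
+ # direction token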
+ assert (
+ (
+ (quizzes[:, 0] == self.token_forward)
+ | (quizzes[:, 0] == self.token_backward)
+ )
+ & (quizzes[:, 0] == quizzes[:, 1 * (S + 1)])
+ & (quizzes[:, 0] == quizzes[:, 2 * (S + 1)])
+ & (quizzes[:, 0] == quizzes[:, 3 * (S + 1)])
+ ).all()
+
+ T = torch.arange(quizzes.size(1), device=quizzes.device)
+
+ if first:
+ forward_mask = ((T % (S + 1) != 0) & (T < 3 * (S + 1))).long()
+ backward_mask = ((T % (S + 1) != 0) & (T < S + 1)).long()
+ else:
+ forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
+ backward_mask = ((T % (S + 1) != 0) & (T >= S + 1)).long()
+
+ is_forward = (quizzes[:, 0] == self.token_forward).long()
+
+ return (
+ is_forward[:, None] * forward_mask[None, :]
+ + (1 - is_forward)[:, None] * backward_mask[None, :]
+ )
+
+ def p_a_flip(self, quizzes):
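+ # Reverses the order of the four (S + 1)-token segments and toggles
+ # every direction token, turning a forward quiz [> A > f(A) > B > f(B)]
+ # into its backward form [< f(B) < B < f(A) < A], and conversely.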
+ S = self.height * self.width
+
+ assert (
+ (
+ (quizzes[:, 0] == self.token_forward)
+ | (quizzes[:, 0] == self.token_backward)
+ )
+ & (quizzes[:, 0] == quizzes[:, 1 * (S + 1)])
+ & (quizzes[:, 0] == quizzes[:, 2 * (S + 1)])
+ & (quizzes[:, 0] == quizzes[:, 3 * (S + 1)])
+ ).all()
+
+ flipped = torch.cat(
+ [quizzes[:, k * (S + 1) : (k + 1) * (S + 1)] for k in range(3, -1, -1)],
+ dim=1,
+ )
+
+ m = (flipped[:, 0] == self.token_forward).long()
+ for k in range(4):
+ flipped[:, k * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward
+
+ return flipped
+
def __init__(
self,
max_nb_cached_chunks=None,
######################################################################
def frame2img(self, x, scale=15):
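+ # Renders a batch of flattened grids as RGB images: every cell becomes
+ # a scale x scale block separated by 1-pixel black lines, and cells
+ # holding the forward / backward tokens are drawn as '>' / '<' chevrons.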
- x = x.reshape(x.size(0), self.height, -1)
+ x = x.reshape(x.size(0), self.height, self.width)
m = torch.logical_and(x >= 0, x < len(self.colors)).long()
- x = self.colors[x * m].permute(0, 3, 1, 2)
- s = x.shape
- x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
- x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
+ y = self.colors[x * m].permute(0, 3, 1, 2)
+ s = y.shape
+ y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
+ y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
- x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
- x[:, :, torch.arange(0, x.size(2), scale), :] = 0
- x = x[:, :, 1:, 1:]
+ y[:, :, :, torch.arange(0, y.size(3), scale)] = 0
+ y[:, :, torch.arange(0, y.size(2), scale), :] = 0
+ y = y[:, :, 1:, 1:]
for n in range(m.size(0)):
for i in range(m.size(1)):
for j in range(m.size(2)):
- if m[n, i, j] == 0:
+ if x[n, i, j] == self.token_forward:
+ for k in range(2, scale - 2):
+ y[
+ n,
+ :,
+ i * scale + k,
+ j * scale + scale - 5 - abs(k - scale // 2),
+ ] = 0
+
+ elif x[n, i, j] == self.token_backward:
for k in range(2, scale - 2):
- for l in [0, 1]:
- x[n, :, i * scale + k, j * scale + k - l] = 0
- x[
- n, :, i * scale + scale - 1 - k, j * scale + k - l
- ] = 0
+ y[
+ n, :, i * scale + k, j * scale + 3 + abs(k - scale // 2)
+ ] = 0
- return x
+ return y
def save_image(
self,
)
f_B = answer[1 : S + 1].view(self.height, self.width)
task = tasks[torch.randint(len(tasks), (1,)).item()]
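+ # clear the four grids before the task draws into them, so that cells
+ # the task leaves untouched are guaranteed to be empty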
+ A[...] = 0
+ f_A[...] = 0
+ B[...] = 0
+ f_B[...] = 0
task(A, f_A, B, f_B)
return prompts.flatten(1), answers.flatten(1)
# exit(0)
# if True:
- nb, nrow = 128, 4
+ nb, nrow = 8, 2
- # nb, nrow = 8, 2
for t in grids.all_tasks:
# for t in [grids.task_compute]:
print(t.__name__)
prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
+ # prompts[...] = torch.randint(grids.nb_token_values(), prompts.size())
grids.save_quiz_illustrations(
"/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
)
- # exit(0)
+ exit(0)
nb = 1000
- # for t in grids.all_tasks:
- for t in [grids.task_compute]:
+ for t in grids.all_tasks:
+ # for t in [grids.task_compute]:
start_time = time.perf_counter()
prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
delay = time.perf_counter() - start_time
# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/
+# > A > f(A) > B ; > f(B)
+# < f(B) ; < B < f(A) < A
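+#
+# A quiz is a sequence of 4 * (S + 1) tokens, S = height * width: four
+# grids, each prefixed with a direction token. Forward ('>') quizzes
+# present A, f(A), B and ask for f(B); backward ('<') quizzes present
+# f(B) and ask for B, f(A), A, in that reversed order.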
+
# Written by Francois Fleuret <francois@fleuret.org>
import math, sys, argparse, time, tqdm, os, datetime, warnings
v_train = validated_quizzes[:nb_for_train]
quiz_machine.store_c_quizzes(v_train, for_train=True)
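+ # also store the flipped orientation, so that both directions are
+ # represented in the training data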
- quiz_machine.store_c_quizzes(quiz_machine.p_a_flip(v_train), for_train=True)
+ quiz_machine.store_c_quizzes(quiz_machine.problem.p_a_flip(v_train), for_train=True)
v_test = validated_quizzes[nb_for_train:nb_to_create]
quiz_machine.store_c_quizzes(v_test, for_train=False)
- quiz_machine.store_c_quizzes(quiz_machine.p_a_flip(v_test), for_train=False)
+ quiz_machine.store_c_quizzes(quiz_machine.problem.p_a_flip(v_test), for_train=False)
######################################################################
# save images
n_p2a = quizzes[quizzes[:, 0] == self.problem.token_forward]
n_a2p = quizzes[:, 0] == self.problem.token_backward
a2p = quizzes[n_a2p]
- quizzes[n_a2p] = self.p_a_flip(quizzes[n_a2p])
+ quizzes[n_a2p] = self.problem.p_a_flip(quizzes[n_a2p])
return torch.logical_not(
self.problem.trivial_prompts_and_answers(
quizzes[:, : self.prompt_len], quizzes[:, self.prompt_len :]
)
)
- def p_a_flip(self, quizzes):
- i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes)
-
- p2a_to_a2p = torch.cat(
- [quizzes[:, self.prompt_len :], quizzes[:, : self.prompt_len]],
- dim=1,
- )
-
- p2a_to_a2p[:, 0] = self.problem.token_backward
- p2a_to_a2p[:, self.answer_len] = self.problem.token_backward
-
- a2p_to_p2a = torch.cat(
- [quizzes[:, self.answer_len :], quizzes[:, : self.answer_len]],
- dim=1,
- )
-
- a2p_to_p2a[:, 0] = self.problem.token_forward
- a2p_to_p2a[:, self.prompt_len] = self.problem.token_forward
-
- m = i_p2a.long()[:, None]
-
- return m * p2a_to_a2p + (1 - m) * a2p_to_p2a
-
def p_a_flip_half_in_place(self, quizzes):
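+ # flips each quiz, in place, to the other direction with probability 1/2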
i = torch.rand(quizzes.size(0)) < 0.5
if i.any():
- quizzes[i] = self.p_a_flip(quizzes[i])
-
- def make_ar_mask(self, quizzes, first=False):
- i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes)
-
- t = torch.arange(quizzes.size(1), device=quizzes.device)
-
- if first:
- m_p2a = (t >= 1).long() * (t < self.prompt_len).long()
- m_a2p = (t >= 1).long() * (t < self.answer_len).long()
- else:
- m_p2a = (t >= 1 + self.prompt_len).long()
- m_a2p = (t >= 1 + self.answer_len).long()
-
- m = i_p2a.long()[:, None]
-
- return m * m_p2a + (1 - m) * m_a2p
+ quizzes[i] = self.problem.p_a_flip(quizzes[i])
def generate_token_sequences(self, nb):
prompts, answers = self.problem.generate_prompts_and_answers(nb)
n_a2p = quizzes[:, 0] == self.problem.token_backward
a2p = quizzes[n_a2p]
assert n_p2a.size(0) + a2p.size(0) == quizzes.size(0)
- quizzes[n_a2p] = self.p_a_flip(quizzes[n_a2p])
+ quizzes[n_a2p] = self.problem.p_a_flip(quizzes[n_a2p])
if show_part_to_predict:
predicted_prompts = n_a2p.long()
def produce_results(self, n_epoch, model, result_dir, deterministic_synthesis):
def compute_accuracy(input, log_prefix=None):
input = input.to(self.device)
- ar_mask = self.make_ar_mask(input)
+ ar_mask = self.problem.make_ar_mask(input)
result = input.clone() * (1 - ar_mask)
seq_logproba = torch.empty(input.size(0), device=self.device)
if self.back_accuracy and n_a2p.any():
# accuracy of B->A*->B*=B instead of B->A*=A
- back_input = self.p_a_flip(result[n_a2p])
+ back_input = self.problem.p_a_flip(result[n_a2p])
back_input[:, 1 + self.prompt_len :] = input[n_a2p, 1 : self.answer_len]
_, correct[n_a2p] = compute_accuracy(back_input)
c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
):
input = input.to(self.device)
- ar_mask = self.make_ar_mask(input)
+ ar_mask = self.problem.make_ar_mask(input)
output = model(mygpt.BracketedSequence(input)).x
l[:, model.id] = (
-F.cross_entropy(
c_quizzes = c_quizzes.to(self.device)
result = c_quizzes.clone()
- ar_mask = self.make_ar_mask(result)
+ ar_mask = self.problem.make_ar_mask(result)
masked_inplace_autoregression(
model=model,
seq_logproba = torch.zeros(nb, device=self.device)
if p2a_only:
- c_quizzes[:, 0] = self.problem.token_forward
- c_quizzes[:, self.prompt_len] = self.problem.token_forward
+ c_quizzes[...] = self.problem.token_forward
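+ # filling the whole sequence with the direction token is safe: the two
+ # generation passes below overwrite every non-header position, and the
+ # segment headers, which make_ar_mask never selects, must carry this
+ # token anyway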
masked_inplace_autoregression(
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=self.make_ar_mask(c_quizzes, first=True),
+ ar_mask=self.problem.make_ar_mask(c_quizzes, first=True),
seq_logproba=seq_logproba,
temperature=temperature_hot,
deterministic_synthesis=False,
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=self.make_ar_mask(c_quizzes),
+ ar_mask=self.problem.make_ar_mask(c_quizzes),
seq_logproba=seq_logproba,
temperature=temperature_cold,
deterministic_synthesis=False,
)
else:
- c_quizzes[:, 0] = self.problem.token_backward
- c_quizzes[:, self.answer_len] = self.problem.token_backward
+ c_quizzes[...] = self.problem.token_backward
masked_inplace_autoregression(
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=self.make_ar_mask(c_quizzes, first=True),
+ ar_mask=self.problem.make_ar_mask(c_quizzes, first=True),
seq_logproba=seq_logproba,
temperature=temperature_hot,
deterministic_synthesis=False,
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=self.make_ar_mask(c_quizzes),
+ ar_mask=self.problem.make_ar_mask(c_quizzes),
seq_logproba=seq_logproba,
temperature=temperature_cold,
deterministic_synthesis=False,
device=self.device,
)
- c_quizzes = self.p_a_flip(c_quizzes)
+ c_quizzes = self.problem.p_a_flip(c_quizzes)
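+ # flip the generated a2p quiz into p2a order and regenerate its answer
+ # cells in that direction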
masked_inplace_autoregression(
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=self.make_ar_mask(c_quizzes),
+ ar_mask=self.problem.make_ar_mask(c_quizzes),
seq_logproba=seq_logproba,
temperature=temperature_cold,
deterministic_synthesis=False,