self.token_f_B: "f_B",
}
- self.nb_token_values = self.token_f_B + 1
-
self.height = 10
self.width = 10
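+ # A quiz is four blocks (A, f_A, B, f_B), each made of one structure token
+ # followed by height x width cell tokens.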
+ self.seq_len = 4 * (1 + self.height * self.width)
+ self.nb_token_values = self.token_f_B + 1
+
self.cache_rec_coo = {}
all_tasks = [
######################################################################
+ def create_empty_quizzes(self, nb, struct=("A", "f_A", "B", "f_B")):
+ S = self.height * self.width
+ quizzes = torch.zeros(nb, 4 * (S + 1), dtype=torch.int64)
+ quizzes[:, 0 * (S + 1)] = self.l2tok(struct[0])
+ quizzes[:, 1 * (S + 1)] = self.l2tok(struct[1])
+ quizzes[:, 2 * (S + 1)] = self.l2tok(struct[2])
+ quizzes[:, 3 * (S + 1)] = self.l2tok(struct[3])
+
+ return quizzes
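+
+ # Usage sketch (assuming a `grids` instance of this class): with the 10x10
+ # grid above, create_empty_quizzes(2) returns a 2 x 404 tensor whose cells
+ # are all zero and whose structure tokens sit at offsets 0, 101, 202, 303:
+ #
+ #   quizzes = grids.create_empty_quizzes(2)
+ #   quizzes.shape          # torch.Size([2, 404])
+ #   quizzes[0, 0::101]     # the tokens for "A", "f_A", "B", "f_B"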
+
def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
if tasks is None:
tasks = self.all_tasks
S = self.height * self.width
- quizzes = torch.empty(nb, 4 * (S + 1), dtype=torch.int64)
+ quizzes = self.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B"))
if progress_bar:
quizzes = tqdm.tqdm(
quizzes,
dynamic_ncols=True,
desc="world quizzes generation",
- total=prompts.size(0),
+ total=quizzes.size(0),
)
- quizzes[...] = 0
- quizzes[:, 0 * (S + 1)] = self.token_A
- quizzes[:, 1 * (S + 1)] = self.token_f_A
- quizzes[:, 2 * (S + 1)] = self.token_B
- quizzes[:, 3 * (S + 1)] = self.token_f_B
-
for quiz in quizzes:
q = quiz.reshape(4, S + 1)[:, 1:].reshape(4, self.height, self.width)
q[...] = 0
nb, nrow = 128, 4
for t in self.all_tasks:
print(t.__name__)
- prompts, answers = self.generate_w_quizzes_(nb, tasks=[t])
+ quizzes = self.generate_w_quizzes_(nb, tasks=[t])
self.save_quizzes_as_image(
- result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+ result_dir, t.__name__ + ".png", quizzes, nrow=nrow
)
predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
- grids.save_quiz_illustrations(
+ grids.save_quizzes_as_image(
"/tmp",
- "test",
+ "test.png",
prompts[:nb],
answers[:nb],
# You can add a bool to put a frame around the predicted parts
parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
-parser.add_argument("--p2a_only", action="store_true", default=False)
-
parser.add_argument("--dirty_debug", action="store_true", default=False)
######################################################################
acc_train_loss += loss.item() * input.size(0)
loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
- n_p2a = input[:, 0] == quiz_machine.problem.token_forward
- to_store = from_w & n_p2a.to("cpu")
- if to_store.any():
+ if from_w.any():
hard_w_quizzes.append(
- (input[to_store].to("cpu"), loss_per_samples[to_store].to("cpu"))
+ (input[from_w].to("cpu"), loss_per_samples[from_w].to("cpu"))
)
nb_train_samples += input.size(0)
c_quizzes = quiz_machine.generate_c_quizzes(
nb_to_generate_per_iteration,
model_for_generation=model_for_generation,
- p2a_only=args.p2a_only,
temperature_hot=args.temperature_hot,
temperature_cold=args.temperature_cold,
)
model=model,
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- p2a_only=args.p2a_only,
)
models.append(model)
c_quizzes = quiz_machine.generate_c_quizzes(
128,
model_for_generation=model,
- p2a_only=args.p2a_only,
temperature_hot=args.temperature_hot,
temperature_cold=args.temperature_cold,
)
# Renew the training samples
for model in weakest_models:
- quiz_machine.renew_train_w_quizzes(model=model, p2a_only=args.p2a_only)
+ quiz_machine.renew_train_w_quizzes(model=model)
if args.log_command is not None:
s = args.log_command.split()
else:
return self.queue.qsize() * self.chunk_size
- def nb_token_values(self):
- pass
-
- def trivial_prompts_and_answers(self, prompts, answers):
- pass
-
- # The one to implement, returns two tensors nb x D and nb x D'
- def generate_w_quizzes_(self, nb):
- pass
-
- # save a file to vizualize quizzes, you can save a txt or png file
- def save_quiz_illustrations(
- self,
- result_dir,
- filename_prefix,
- prompts,
- answers,
- predicted_prompts=None,
- predicted_answers=None,
- ):
- pass
-
def fill_cache(self):
while True:
- prompts, answers = self.generate_w_quizzes_(self.chunk_size)
-
- self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
+ quizzes = self.generate_w_quizzes_(self.chunk_size)
+ self.queue.put(quizzes.to("cpu"), block=True)
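+
+ # fill_cache is intended to run in background workers: each iteration
+ # generates a chunk of self.chunk_size world quizzes and pushes it onto
+ # self.queue; put(block=True) throttles the workers once the queue is full,
+ # and generate_w_quizzes below drains it chunk by chunk.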
def generate_w_quizzes(self, nb):
if self.queue is None:
return self.generate_w_quizzes_(nb)
if self.rest is not None:
- prompts, answers = rest
+ quizzes = [self.rest]
else:
- prompts, answers = [], []
+ quizzes = []
self.rest = None
- n = sum([p.size(0) for p in prompts])
+ n = sum([q.size(0) for q in quizzes])
with tqdm.tqdm(
total=nb,
desc="world generation",
) as pbar:
while n < nb:
- p, s = self.queue.get(block=True)
- prompts.append(p)
- answers.append(s)
- n += p.size(0)
- pbar.update(p.size(0))
+ q = self.queue.get(block=True)
+ quizzes.append(q)
+ n += q.size(0)
+ pbar.update(q.size(0))
- prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
- assert n == prompts.size(0)
+ quizzes = torch.cat(quizzes, dim=0)
+ assert n == quizzes.size(0)
k = n - nb
if k > 0:
- rest = (prompts[-k:], answers[-k:])
- prompts, answers = prompts[:-k], answers[:-k]
+ self.rest = quizzes[-k:]
+ quizzes = quizzes[:-k]
- return prompts, answers
+ return quizzes
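+
+ # Consumer-side sketch (assuming a queue-backed instance named `problem`):
+ # the call blocks until at least nb rows have been collected, and any
+ # surplus from the last chunk is kept in self.rest for the next call.
+ #
+ #   quizzes = problem.generate_w_quizzes(1000)
+ #   assert quizzes.size(0) == 1000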
+
+ ######################################################################
+
+ def trivial_prompts_and_answers(self, prompts, answers):
+ pass
+
+ # The one to implement, returns a tensor of quizzes of shape nb x seq_len
+ def generate_w_quizzes_(self, nb):
+ pass
+
+ # save a file to visualize quizzes, you can save a txt or png file
+ def save_quiz_illustrations(
+ self,
+ result_dir,
+ filename_prefix,
+ prompts,
+ answers,
+ predicted_prompts=None,
+ predicted_answers=None,
+ ):
+ pass
def save_some_examples(self, result_dir):
pass
+
+ ######################################################################
######################################################################
- def produce_results(
- self, n_epoch, model, input, result_dir, deterministic_synthesis
- ):
- def predict(input, struct, mask):
- ar_mask = self.problem.make_ar_mask(
- quizzes=quizzes, struct=struct, mask=mask
- )
- result = quizzes * (1 - ar_mask)
- seq_logproba = torch.empty(fwd_quizzes, device=self.device)
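+ # Fill in the quiz blocks selected by `mask` (one flag per block, e.g.
+ # (0, 0, 0, 1) regenerates only the final f_B grid) by autoregressive
+ # sampling with `model`, and return the completed quizzes together with a
+ # per-quiz correctness flag against the original input.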
+ def predict(self, model, input, struct, mask, deterministic_synthesis=False):
+ ar_mask = self.problem.make_ar_mask(quizzes=input, struct=struct, mask=mask)
+ result = input * (1 - ar_mask)
- masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=result,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba,
- deterministic_synthesis=deterministic_synthesis,
- progress_bar_desc="accuracy",
- device=self.device,
- )
+ seq_logproba = torch.empty(input.size(0), device=self.device)
- nb_correct = (result == quizzes).min(dim=1).long()
+ masked_inplace_autoregression(
+ model=model,
+ batch_size=self.batch_size,
+ input=result,
+ ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
+ deterministic_synthesis=deterministic_synthesis,
+ progress_bar_desc="accuracy",
+ device=self.device,
+ )
+
+ correct = (result == input).min(dim=1).values.long()
- return result, correct
+ return result, correct
+ def produce_results(
+ self, n_epoch, model, input, result_dir, deterministic_synthesis
+ ):
input = input.to(self.device)
i = self.problem.indices_select(quizzes=input, struct=struct)
+ input_fwd = input[i]
- test_result_fwd, test_correct_fwd = predict(
- input[i], ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+ test_result_fwd, test_correct_fwd = self.predict(
+ model, input_fwd, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), deterministic_synthesis
)
input_bck = self.problem.reconfigure(
struct=("A", "f_A", "B", "f_B"),
)
- l = input_bck.size(1)
+ l = input_bck.size(1) // 4
input_bck[:, 3 * l :] = input[i == False][:, :l]
+
- test_result_bck, test_correct_bck = predict(
- input_bck, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+ test_result_bck, test_correct_bck = self.predict(
+ model, input_bck, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), deterministic_synthesis
)
##############################
+ test_result = torch.cat([test_result_fwd[:64], test_result_bck[:64]], dim=0)
+ test_correct = torch.cat([test_correct_fwd[:64], test_correct_bck[:64]], dim=0)
+
self.save_quiz_illustrations(
result_dir,
f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
- quizzes=test_result[:128],
- mistakes=test_correct[:128] * 2 - 1,
+ quizzes=test_result,
+ # mistakes=test_correct,
)
return main_test_accuracy
######################################################################
def flip_half_in_place(self, quizzes):
- r = torch.randint(quizzes.size(0), device=quizzes.device) < 0.5
- i = self.problem.indices_select(quizzes=input, struct=("A", "f_A", "B", "f_B"))
+ r = torch.rand(quizzes.size(0), device=quizzes.device) < 0.5
+ i = self.problem.indices_select(
+ quizzes=quizzes, struct=("A", "f_A", "B", "f_B")
+ )
+ # Compute both orientation masks before any flip, otherwise the quizzes
+ # just reconfigured to ("f_B", "f_A", "B", "A") would be flipped back.
+ j = self.problem.indices_select(
+ quizzes=quizzes, struct=("f_B", "f_A", "B", "A")
+ )
quizzes[i & r] = self.problem.reconfigure(
quizzes[i & r], struct=("f_B", "f_A", "B", "A")
)
- j = self.problem.indices_select(quizzes=input, struct=("f_B", "f_A", "B", "A"))
quizzes[j & r] = self.problem.reconfigure(
quizzes[j & r], struct=("A", "f_A", "B", "f_B")
)
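+
+ # Net effect: for roughly half of the quizzes (those with r == True) the
+ # block order is toggled between ("A", "f_A", "B", "f_B") and
+ # ("f_B", "f_A", "B", "A"); the other half is left untouched.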
):
c_quizzes = torch.empty(
nb,
- self.prompt_len + self.answer_len,
+ self.problem.seq_len,
device=self.device,
dtype=torch.int64,
)