targets = input
input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA
+
result = (1 - mask_generate) * input + mask_generate * torch.randint(
quiz_machine.problem.nb_colors, input.size(), device=input.device
)
result[not_converged] = update[not_converged]
not_converged = (pred_result != result).max(dim=1).values
nb_it += 1
- print("DEBUG", nb_it, i.long().sum().item())
- if not i.any() or nb_it > 100:
+ print("DEBUG", nb_it, not_converged.long().sum().item())
+ if not not_converged.any() or nb_it > 100:
break
correct = (result == targets).min(dim=1).values.long()
self.answer_len = None
self.prompt_noise = prompt_noise
- # struct, quad_generate, quad_noise, quad_loss
+ # quad_order, quad_generate, quad_noise, quad_loss
self.train_structures = [
(("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
(("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
else:
quizzes = self.problem.generate_w_quizzes(nb_samples)
+ # shuffle
+
i = torch.randperm(quizzes.size(0), device=quizzes.device)
quizzes = quizzes[i]
- self.randomize_configuations_inplace(
- quizzes, quad_orders=[s for s, _, _, _ in data_structures]
- )
+ # Re-order and inject noise
quiz_mask_generate = quizzes.new_full(quizzes.size(), 1)
quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
-
- for quad_order, quad_generate, quad_noise, quad_loss in data_structures:
- i = self.problem.indices_select(quizzes=quizzes, quad_order=quad_order)
- if i.any():
- if self.prompt_noise > 0.0:
- quizzes[i] = self.problem.inject_noise(
- quizzes[i],
- self.prompt_noise,
- quad_order=quad_order,
- quad_noise=quad_noise,
- )
- quiz_mask_generate[i] = self.make_quiz_mask(
- quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_generate
- )
- quiz_mask_loss[i] = self.make_quiz_mask(
- quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_loss
+ order_ids = torch.randint(len(data_structures), (quizzes.size(0),))
+
+ for j, s in enumerate(data_structures):
+ quad_order, quad_generate, quad_noise, quad_loss = s
+ i = order_ids == j
+ quizzes[i] = self.problem.reconfigure(quizzes[i], quad_order=quad_order)
+ if self.prompt_noise > 0.0:
+ quizzes[i] = self.problem.inject_noise(
+ quizzes[i],
+ self.prompt_noise,
+ quad_order=quad_order,
+ quad_noise=quad_noise,
)
+ quiz_mask_generate[i] = self.make_quiz_mask(
+ quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_generate
+ )
+ quiz_mask_loss[i] = self.make_quiz_mask(
+ quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_loss
+ )
return quizzes, quiz_mask_generate, quiz_mask_loss