X[i, j] = c[1]
f_X[...] = (A == A[i, j]) * c[1] + ((A > 0) & (A != A[i, j])) * c[0]
+    # @torch.compile
+    def task_stack(self, A, f_A, B, f_B):
+        # Draw a colored bar on one side of each of nine 3x3 cells in X; in
+        # f_X the first color fills a central 2x2 block and every later color
+        # adds a one-cell-wide strip on the corresponding side, in drawing order
+        N = 5
+        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            i1, j1, i2, j2 = (
+                self.height // 2 - 1,
+                self.width // 2 - 1,
+                self.height // 2 + 1,
+                self.width // 2 + 1,
+            )
+            op = torch.tensor((0, 1, 2, 3) * 4)
+            op = op[torch.randperm(op.size(0))[:9]]
+            for q in range(op.size(0)):
+                u = 3 * (q // 3)
+                v = 3 * (q % 3)
+                d = c[torch.randint(N, (1,)).item()]
+                # X[u+1,v+1]=d
+                if op[q] == 0:  # right
+                    X[u : u + 3, v + 2] = d
+                elif op[q] == 1:  # left
+                    X[u : u + 3, v] = d
+                elif op[q] == 2:  # bottom
+                    X[u + 2, v : v + 3] = d
+                elif op[q] == 3:  # top
+                    X[u, v : v + 3] = d
+
+                if q == 0:
+                    f_X[i1:i2, j1:j2] = d
+                elif op[q] == 0:  # right
+                    f_X[i1:i2, j2] = d
+                    j2 += 1
+                elif op[q] == 1:  # left
+                    j1 -= 1
+                    f_X[i1:i2, j1] = d
+                elif op[q] == 2:  # bottom
+                    f_X[i2, j1:j2] = d
+                    i2 += 1
+                elif op[q] == 3:  # top
+                    i1 -= 1
+                    f_X[i1, j1:j2] = d
+
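# A small illustrative sketch (not part of the patch): in task_stack the quiz
# index q maps row-major to the top-left corner (u, v) of one of nine 3x3 cells.
for q in range(9):
    u, v = 3 * (q // 3), 3 * (q % 3)
    print(q, (u, v))  # (0, (0, 0)), (1, (0, 3)), ..., (8, (6, 6))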
+    def randint(self, *m):
+        # One uniform integer per upper bound: the k-th entry of the result
+        # lies in [0, m[k])
+        m = torch.tensor(m)
+        return (torch.rand(m.size()) * m).long()
+
+    def task_matrices(self, A, f_A, B, f_B):
+        # X shows two random 5x5 binary matrices M1 and M2 side by side; f_X
+        # repeats them and adds their product M1 @ M2 in the bottom-right 5x5
+        # block, everything rendered with the sampled color palette
+        N = 6
+        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            M1 = torch.randint(2, (5, 5))
+            M2 = torch.randint(2, (5, 5))
+            P = M1 @ M2
+            for i in range(5):
+                for j in range(5):
+                    X[i, j] = c[M1[i, j]]
+                    X[i, j + 5] = c[M2[i, j]]
+                    f_X[i, j] = c[M1[i, j]]
+                    f_X[i, j + 5] = c[M2[i, j]]
+                    f_X[i + 5, j + 5] = c[P[i, j]]
+
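# A quick sanity sketch for task_matrices above (not part of the patch): the
# entries of the product of two 5x5 {0,1} matrices lie in [0, 5], so indexing
# the palette of N = 6 colors is always safe.
import torch

M1 = torch.randint(2, (5, 5))
M2 = torch.randint(2, (5, 5))
P = M1 @ M2
assert P.min() >= 0 and P.max() <= 5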
######################################################################
    def trivial_prompts_and_answers(self, prompts, answers):
    # nb, nrow = 8, 2
    # for t in grids.all_tasks:
-    for t in [grids.task_count]:
+    for t in [grids.task_matrices]:
        print(t.__name__)
        prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
        grids.save_quiz_illustrations(
parser.add_argument("--proba_not_understands", type=float, default=0.5)
-# parser.add_argument("--generation_temperature", type=float, default=2)
+parser.add_argument("--generation_temperature", type=float, default=2)
parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
    start_time = time.perf_counter()
+    nb_validated = torch.zeros(len(models), dtype=torch.int64)
+
-    while nb_validated < nb_to_create:
-        model_for_generation = models[torch.randint(len(models), (1,))]
+    while nb_validated.sum() < nb_to_create:
+        # We balance the number of quizzes per model: pick the one with the
+        # fewest validated quizzes so far
+
+        model_for_generation = models[nb_validated.argmin()]
        c_quizzes = quiz_machine.generate_c_quizzes(
            nb_to_generate_per_iteration,
            model_for_generation=model_for_generation,
            forward_only=args.forward_only,
+            generation_temperature=args.generation_temperature,
        )
        c_quizzes = keep_good_quizzes(models, c_quizzes)
-        nb_validated += c_quizzes.size(0)
+        nb_validated[model_for_generation.id] += c_quizzes.size(0)
+        total_nb_validated = nb_validated.sum().item()
        recorded.append(c_quizzes)
        duration = time.perf_counter() - start_time
-        if nb_validated > 0 and nb_validated < nb_to_create:
-            d = (nb_to_create - nb_validated) * duration / nb_validated
+        if total_nb_validated > 0 and total_nb_validated < nb_to_create:
+            d = (nb_to_create - total_nb_validated) * duration / total_nb_validated
            e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime(
                "%a %H:%M"
            )
        else:
            e = "???"
        log_string(
-            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create} (finishes {e})"
+            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e})"
        )
    validated_quizzes = torch.cat(recorded, dim=0)
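# A toy sketch of the balancing policy above (counts are made up): the next
# generator is the model with the fewest validated quizzes so far.
import torch

nb_validated = torch.tensor([12, 3, 7])
assert nb_validated.argmin().item() == 1  # model 1 generates next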
###############################################################
-    def generate_c_quizzes(self, nb, model_for_generation, forward_only=False):
+    def generate_c_quizzes(
+        self,
+        nb,
+        model_for_generation,
+        forward_only=False,
+        generation_temperature=1.0,
+    ):
        c_quizzes = torch.empty(
            nb,
            self.prompt_len + self.answer_len + 2,
            input=c_quizzes,
            ar_mask=self.make_ar_mask(c_quizzes, first=True),
            seq_logproba=seq_logproba,
-            temperature=1.0,
+            temperature=generation_temperature,
            deterministic_synthesis=False,
            device=self.device,
        )
            input=c_quizzes,
            ar_mask=self.make_ar_mask(c_quizzes),
            seq_logproba=seq_logproba,
-            temperature=1,
+            temperature=1.0,
            deterministic_synthesis=False,
            device=self.device,
        )
            input=c_quizzes,
            ar_mask=self.make_ar_mask(c_quizzes, first=True),
            seq_logproba=seq_logproba,
-            temperature=1.0,
+            temperature=generation_temperature,
            deterministic_synthesis=False,
            device=self.device,
        )
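# A short sketch of what generation_temperature does during sampling (standard
# logit scaling; the actual autoregression helper is defined elsewhere):
import torch

logits = torch.tensor([2.0, 1.0, 0.0])
for T in (0.5, 1.0, 2.0):
    p = torch.softmax(logits / T, dim=0)
    print(T, p)  # lower T sharpens the distribution, higher T flattens it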