while True:
error = False
- N = torch.randint(5, (1,)).item() + 2
- c = torch.zeros(N + 1, dtype=torch.int64)
- c[1:] = torch.randperm(len(self.colors) - 1)[:N] + 1
+ N = 3
+ c = torch.zeros(N + 2, dtype=torch.int64)
+ c[1:] = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
if not hasattr(self, "cache_count") or len(self.cache_count) == 0:
self.height,
self.width,
nb_seeds=self.height * self.width // 8,
- nb_iterations=self.height * self.width // 10,
+ nb_iterations=self.height * self.width // 5,
)
)
# V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
# c.size(0) - 1
# ) + 1
- V = torch.randint(c.size(0) - 1, (X.max() + 1,)) + 1
+
+ V = torch.randint(N, (X.max() + 1,)) + 1
V[0] = 0
NB = F.one_hot(c[V]).sum(dim=0)
X[...] = c[V[X]]
-
- if F.one_hot(X.flatten()).max(dim=0).values.sum().item() == N + 1:
- f_X[...] = 0
- for e in range(1, N + 1):
- for j in range(NB[c[e]]):
- if j < self.width:
- f_X[e - 1, j] = c[e]
- else:
- error = True
- break
+ f_X[...] = X
+
+ if F.one_hot(X.flatten()).max(dim=0).values.sum().item() >= 3:
+ m = NB[c[:-1]].max()
+ if (NB[c[:-1]] == m).long().sum() == 1:
+ for e in range(1, N + 1):
+ if NB[c[e]] == m:
+ a = (f_X == c[e]).long()
+ f_X[...] = (1 - a) * f_X + a * c[-1]
else:
error = True
break
# nb, nrow = 8, 2
# for t in grids.all_tasks:
- for t in [grids.task_contact]:
+ for t in [grids.task_count]:
# for t in [grids.task_symbols]:
print(t.__name__)
prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
parser.add_argument("--temperature_cold", type=float, default=0.75)
+parser.add_argument("--nb_rounds", type=int, default=3)
+
parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
parser.add_argument("--p2a_only", action="store_true", default=False)
######################################################################
-def keep_good_quizzes(models, quizzes, required_nb_failures=1):
- quizzes = quizzes[quiz_machine.non_trivial(quizzes)]
-
- if args.c_quiz_validation_mode == "proba":
- token_logprobas = quiz_machine.solution_token_logprobas(models, quizzes)
- l = token_logprobas.sum(dim=-1).sort(dim=-1).values
-
- to_keep = (l[:, 0] < math.log(args.proba_not_understands)) & (
- l[:, 1] > math.log(args.proba_understands)
- )
-
- elif args.c_quiz_validation_mode == "predict":
- nc = quiz_machine.solution_nb_correct(models, quizzes)
-
- count_nc = tuple(
- n.item() for n in F.one_hot(nc, num_classes=len(models) + 1).sum(dim=0)
- )
-
- log_string(f"nb_correct {count_nc}")
-
- to_keep = nc == (len(models) - required_nb_failures)
-
- else:
- raise ValueError(f"{args.c_quiz_validation_mode=}")
-
- if args.dirty_debug:
- # warnings.warn("DEBUG", RuntimeWarning)
- to_keep = torch.rand(to_keep.size(), device=to_keep.device) < 0.5
-
- return quizzes[to_keep]
-
-
-######################################################################
-
-
def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
- nb_to_create = nb_for_train + nb_for_test
- nb_to_generate_per_iteration = nb_to_create
+ nb_to_validate = nb_for_train + nb_for_test
+ nb_to_generate_per_iteration = nb_to_validate
nb_validated = 0
recorded_validated = []
- recorded_too_simple = []
+ # recorded_too_simple = []
start_time = time.perf_counter()
- nb_validated = torch.zeros(len(models), dtype=torch.int64)
+ nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64)
- while nb_validated.sum() < nb_to_create:
+ while nb_validated_per_model.sum() < nb_to_validate:
# We balance the number of quizzes per model
- model_for_generation = sorted(models, key=lambda m: nb_validated[m.id])[0]
+ model_for_generation = sorted(
+ models, key=lambda m: nb_validated_per_model[m.id]
+ )[0]
+
+ # We generate quizzes with a procedure that injects some
+ # structured noise
c_quizzes = quiz_machine.generate_c_quizzes(
nb_to_generate_per_iteration,
temperature_cold=args.temperature_cold,
)
- c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
-
- nc = quiz_machine.solution_nb_correct(models, c_quizzes)
-
- count_nc = tuple(
- n.item() for n in F.one_hot(nc, num_classes=len(models) + 1).sum(dim=0)
- )
-
- log_string(f"nb_correct {count_nc}")
+ # We discard the trivial ones
- recorded_too_simple.append(c_quizzes[nc == len(models)])
+ c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
- c_quizzes = c_quizzes[nc == len(models) - 1]
+ # We run nb_rounds rounds and keep only the quizzes on which
+ # every model's outcome is the same in all rounds
- nb_validated[model_for_generation.id] += c_quizzes.size(0)
- total_nb_validated = nb_validated.sum().item()
+ total_nb_validated = 0
+ ms = 0
+ for r in range(args.nb_rounds):
+ ms += quiz_machine.models_successes(models, c_quizzes)
+ # print(f"{r=} {ms=}")
+ i = ((ms == r + 1).long().sum(dim=1) == ms.size(1) - 1) & (
+ (ms == 0).long().sum(dim=1) == 1
+ )
+ c_quizzes = c_quizzes[i]
+ ms = ms[i]
+ if c_quizzes.size(0) == 0:
+ break
- recorded_validated.append(c_quizzes)
+ if c_quizzes.size(0) > 0:
+ nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0)
+ total_nb_validated = nb_validated_per_model.sum().item()
+ recorded_validated.append(c_quizzes)
duration = time.perf_counter() - start_time
if total_nb_validated > 0:
- if total_nb_validated < nb_to_create:
- d = (nb_to_create - total_nb_validated) * duration / total_nb_validated
+ if total_nb_validated < nb_to_validate:
+ d = (
+ (nb_to_validate - total_nb_validated)
+ * duration
+ / total_nb_validated
+ )
e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime(
"%a %H:%M"
)
e = "???"
log_string(
- f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)"
+ f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)"
)
validated_quizzes = torch.cat(recorded_validated, dim=0)
- too_simple_quizzes = torch.cat(recorded_too_simple, dim=0)
+ # too_simple_quizzes = torch.cat(recorded_too_simple, dim=0)
######################################################################
# store the new c_quizzes which have been validated
quiz_machine.store_c_quizzes(v_train, for_train=True)
quiz_machine.store_c_quizzes(quiz_machine.problem.p_a_flip(v_train), for_train=True)
- v_test = validated_quizzes[nb_for_train:nb_to_create]
+ v_test = validated_quizzes[nb_for_train:nb_to_validate]
quiz_machine.store_c_quizzes(v_test, for_train=False)
quiz_machine.store_c_quizzes(quiz_machine.problem.p_a_flip(v_test), for_train=False)
args.result_dir, prefix, vq, show_part_to_predict=False
)
- vq = too_simple_quizzes[torch.randperm(too_simple_quizzes.size(0))[:128]]
+ # vq = too_simple_quizzes[torch.randperm(too_simple_quizzes.size(0))[:128]]
- if vq.size(0) > 0:
- prefix = f"culture_c_quiz_{n_epoch:04d}_too_simple"
- quiz_machine.save_quiz_illustrations(
- args.result_dir, prefix, vq, show_part_to_predict=False
- )
+ # if vq.size(0) > 0:
+ # prefix = f"culture_c_quiz_{n_epoch:04d}_too_simple"
+ # quiz_machine.save_quiz_illustrations(
+ # args.result_dir, prefix, vq, show_part_to_predict=False
+ # )
######################################################################
input,
ar_mask,
seq_logproba,
- temperature,
- deterministic_synthesis,
+ logit_transformer=None,
+ deterministic_synthesis=False,
):
to_generate = (ar_mask.sum(0) > 0).nonzero()
logits = output[:, s]
- logits = (logits / temperature).log_softmax(dim=-1)
+ logits = logit_transformer(s, logits).log_softmax(dim=-1)
if deterministic_synthesis:
t_next = logits.argmax(-1)
input,
ar_mask,
seq_logproba,
- temperature,
- deterministic_synthesis,
+ logit_transformer=None,
+ deterministic_synthesis=False,
forbidden_tokens=None,
logit_biases=None,
progress_bar_desc=None,
input=input,
ar_mask=ar_mask,
seq_logproba=seq_logproba,
- temperature=temperature,
+ logit_transformer=logit_transformer,
deterministic_synthesis=deterministic_synthesis,
)
input=result,
ar_mask=ar_mask,
seq_logproba=seq_logproba,
- temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc="accuracy",
device=self.device,
###############################################################
- def solution_nb_correct(self, models_for_validation, c_quizzes):
+ def models_successes(self, models_for_validation, c_quizzes):
seq_logproba = torch.zeros(
c_quizzes.size(0),
max([m.id for m in models_for_validation]) + 1,
device=self.device,
)
- nb_correct = 0
- correct_models = torch.empty(
+ correctly_solved = torch.empty(
c_quizzes.size(0),
max([m.id for m in models_for_validation]) + 1,
device=self.device,
input=result,
ar_mask=ar_mask,
seq_logproba=seq_logproba[:, model.id],
- temperature=1.0,
deterministic_synthesis=False,
device=self.device,
)
- correct_models[:, model.id] = (
- (c_quizzes == result).long().min(dim=-1).values
- )
+ correct = (c_quizzes == result).long().min(dim=-1).values
# -------------------------------
input=result,
ar_mask=ar_mask,
seq_logproba=seq_logproba[:, model.id],
- temperature=1.0,
deterministic_synthesis=False,
device=self.device,
)
- correct_models[:, model.id] *= (
- (c_quizzes == result).long().min(dim=-1).values
- )
+ correct *= (c_quizzes == result).long().min(dim=-1).values
# -------------------------------
- i = correct_models.sum(dim=1) == correct_models.size(1) - 1
- c = (correct_models[i] == 0).long().sum(dim=0)
- self.logger(f"nb_failures_on_validated {tuple(x.item() for x in c)}")
+ correctly_solved[:, model.id] = correct
- return correct_models.sum(dim=1).to("cpu")
+ return correctly_solved.to("cpu")
###############################################################
seq_logproba = torch.zeros(nb, device=self.device)
+ def heater(T):
+ return lambda s, logits: logits / T
+
if p2a_only:
c_quizzes[...] = self.problem.token_forward
input=c_quizzes,
ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
seq_logproba=seq_logproba,
- temperature=temperature_hot,
+ logit_transformer=heater(temperature_hot),
deterministic_synthesis=False,
device=self.device,
)
input=c_quizzes,
ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
seq_logproba=seq_logproba,
- temperature=temperature_cold,
+ logit_transformer=heater(temperature_cold),
deterministic_synthesis=False,
device=self.device,
)
input=c_quizzes,
ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
seq_logproba=seq_logproba,
- temperature=temperature_hot,
+ logit_transformer=heater(temperature_hot),
deterministic_synthesis=False,
device=self.device,
)
input=c_quizzes,
ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
seq_logproba=seq_logproba,
- temperature=temperature_cold,
+ logit_transformer=heater(temperature_cold),
deterministic_synthesis=False,
device=self.device,
)
input=c_quizzes,
ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
seq_logproba=seq_logproba,
- temperature=temperature_cold,
+ logit_transformer=heater(temperature_cold),
deterministic_synthesis=False,
device=self.device,
)