parser.add_argument("--reboot", action="store_true", default=False)
# ----------------------------------
+
parser.add_argument("--model", type=str, default="37M")
parser.add_argument("--dim_model", type=int, default=None)
parser.add_argument("--dropout", type=float, default=0.5)
# ----------------------------------
+
parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
parser.add_argument("--problem", type=str, default="grids")
parser.add_argument("--max_fail_to_validate", type=int, default=3)
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.98)
parser.add_argument("--proba_understands", type=float, default=0.95)
parser.add_argument("--test", type=str, default=None)
-parser.add_argument("--logit_std_max", type=float, default=-1)
-
######################################################################
grids_tasks = ", ".join(
subparam._grad.data = subparam._grad.data.to(device)
-######################################################################
-
-
-def run_tests(model, quiz_machine, local_device=main_device):
- with torch.autograd.no_grad():
- model.to(local_device).eval()
-
- nb_test_samples, acc_test_loss = 0, 0.0
- nb_samples_accumulated = 0
-
- full_input, _, full_mask_loss = quiz_machine.data_input(
- args.nb_test_samples, model.test_c_quiz_bags, args.c_quiz_multiplier
- )
- src = zip(
- full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
- )
-
- for input, mask_loss in tqdm.tqdm(
- src,
- dynamic_ncols=True,
- desc="test",
- total=full_input.size(0) // args.batch_size,
- ):
- input = input.to(local_device)
- mask_loss = mask_loss.to(local_device)
- targets = input
-
- output = model(input)
- loss_per_token = F.cross_entropy(
- output.transpose(1, 2), targets, reduction="none"
- )
- loss = (loss_per_token * mask_loss).mean()
- acc_test_loss += loss.item() * input.size(0)
- nb_test_samples += input.size(0)
-
- test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
-
- log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
-
- input, _, _ = quiz_machine.data_input(
- 2000, model.test_c_quiz_bags, args.c_quiz_multiplier
- )
-
- model.test_accuracy = quiz_machine.produce_results(
- n_epoch=n_epoch,
- model=model,
- input=input,
- result_dir=args.result_dir,
- )
-
-
-######################################################################
-
-
-def one_epoch(model, quiz_machine, local_device=main_device):
- model.to(local_device).train()
- optimizer_to(model.optimizer, local_device)
-
- nb_train_samples, acc_train_loss = 0, 0.0
-
- full_input, _, full_mask_loss = quiz_machine.data_input(
- args.nb_train_samples,
- model.train_c_quiz_bags + common_c_quiz_bags,
- args.c_quiz_multiplier,
- )
- src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
-
- for input, mask_loss in tqdm.tqdm(
- src,
- dynamic_ncols=True,
- desc="training",
- total=full_input.size(0) // args.batch_size,
- ):
- input = input.to(local_device)
- mask_loss = mask_loss.to(local_device)
-
- if nb_train_samples % args.batch_size == 0:
- model.optimizer.zero_grad()
-
- targets = input
- output = model(input)
- loss = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
- loss = (loss * mask_loss).mean() + model.loss
-
- acc_train_loss += loss.item() * input.size(0)
- nb_train_samples += input.size(0)
-
- loss.backward()
-
- if nb_train_samples % args.batch_size == 0:
- model.optimizer.step()
-
- train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-
- log_string(f"train_perplexity {n_epoch} model {model.id} {train_perplexity}")
-
- run_tests(model, quiz_machine)
-
- model.to(main_device)
- optimizer_to(model.optimizer, main_device)
-
-
-######################################################################
-
-
-def model_modifier_hot(model):
- model.temperature = args.temperature_hot
- # model.set_noise_injection(1.0, ("ffw", args.nb_blocks // 2))
-
-
-def model_modifier_cold(model):
- model.temperature = args.temperature_cold
- # pass
-
-
-c_quizzes_procedure = [
- (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot),
- (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold),
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
- # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold),
-]
-
-# quad_order, quad_generate, quad_noise, quad_loss
-
-data_structures = [
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
- (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
- (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)),
- (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)),
- (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
-]
-
-######################################################################
-
-
-def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
- nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models)
- nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate
-
- start_time = time.perf_counter()
-
- for model in models:
- model.recorded_c_quizzes = []
-
- teaching_count = torch.zeros(len(models), len(models), dtype=torch.int64)
-
- while nb_validated < nb_to_validate:
- model_for_generation = models[torch.randint(len(models), (1,)).item()]
-
- # We generate quizzes with a procedure that injects some
- # structured noise
-
- c_quizzes = quiz_machine.generate_c_quizzes(
- nb_to_generate_per_iteration,
- model_for_generation=model,
- procedure=c_quizzes_procedure,
- )
-
- nb_generated += c_quizzes.size(0)
-
- # We discard the trivial ones, according to a criterion
- # specific to the world quizzes (e.g. B=f(B))
-
- to_keep = quiz_machine.problem.trivial(c_quizzes) == False
-
- c_quizzes = c_quizzes[to_keep]
-
- # Compute the responses of all the models on the c_quizzes,
- # and their proba estimates of their responses
-
- solved_c_quizzes = c_quizzes[:, None, :].expand(-1, len(models), -1).clone()
-
- proba_own_solution = torch.zeros(
- c_quizzes.size(0), len(models), device=solved_c_quizzes.device
- )
-
- for model in models:
- (solved_c_quizzes[:, model.id], _, _) = quiz_machine.predict(
- model,
- solved_c_quizzes[:, model.id],
- quad_orders=("A", "f_A", "B", "f_B"),
- quad=(0, 0, 0, 1),
- )
-
- proba_own_solution[:, model.id] = model_proba_solutions(
- model, solved_c_quizzes[:, model.id]
- )
-
- # Now for every model not confident of its response, we pick
- # the most consistent from a model which is confident
-
- for s in range(proba_own_solution.size(0)):
- # At least one GPT does not understand at all
- if proba_own_solution[s, :].min() < args.proba_not_understands:
- dont_get_this_quiz = proba_own_solution[s, :] < args.proba_understands
- nb_fails = dont_get_this_quiz.long().sum()
- # At most max_fail_to_validate do not understand (default 3/5)
- if nb_fails >= 1 and nb_fails <= args.max_fail_to_validate:
- for model in models:
- # If a GPT does not get that quiz
- if dont_get_this_quiz[model.id]:
- assert (
- proba_own_solution[s, model.id] < args.proba_understands
- )
- # Look at its estimate of the others'solutions
- proba_other_solutions = model_proba_solutions(
- model, solved_c_quizzes[s]
- )
- # Randomize a bit the orders for the frequent P=1
- proba_other_solutions += (
- torch.rand(proba_other_solutions.size()) * 1e-6
- )
- # Remove the under threshold confidence solutions
- proba_other_solutions[dont_get_this_quiz] = -1
- i = proba_other_solutions.argmax()
- model.recorded_c_quizzes.append(solved_c_quizzes[s, i])
- teaching_count[i, model.id] += 1
- nb_validated += 1
-
- duration = time.perf_counter() - start_time
-
- if nb_validated > 0:
- if nb_validated < nb_to_validate:
- d = (nb_to_validate - nb_validated) * duration / nb_validated
- e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime(
- "%a %H:%M"
- )
- else:
- e = "now!"
- else:
- e = "???"
-
- log_string(
- f"keep c_quizzes model {model_for_generation.id} validated nb_validated {nb_validated} / {nb_to_validate} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h) proportion_kept {nb_validated * 100 / nb_generated:.02f}%"
- )
-
- for s in range(teaching_count.size(0)):
- o = [x.item() for x in teaching_count[s]]
- log_string(f"teacher model {s} to {o}")
-
- for model in models:
- new_bag = torch.cat([q[None, :] for q in model.recorded_c_quizzes], dim=0)
-
- if new_bag.size(0) > 0:
- n = (new_bag.size(0) * nb_for_train) // (nb_for_train + nb_for_test)
- if n > 0:
- model.train_c_quiz_bags.append(new_bag[:n])
- if n < new_bag.size(0):
- model.test_c_quiz_bags.append(new_bag[n:])
-
- c_quizzes = new_bag[:128]
-
- l = [model_proba_solutions(model, c_quizzes) for model in models]
- probas = torch.cat([x[:, None] for x in l], dim=1)
- comments = []
-
- for l in probas:
- comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
-
- filename = f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}.png"
- quiz_machine.problem.save_quizzes_as_image(
- args.result_dir, filename, c_quizzes, comments=comments
- )
-
- log_string(
- f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in model.train_c_quiz_bags ])} test {sum([q.size(0) for q in model.test_c_quiz_bags ])}"
- )
-
-
######################################################################
from mygpt import (
######################################################################
+# quad_order, quad_generate, quad_noise, quad_loss
+
+data_structures = [
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
+ (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
+ (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)),
+ (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)),
+ (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
+]
+
def ae_batches(
quiz_machine,
f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
)
- run_ae_test(model, quiz_machine, n_epoch, c_quizzes, local_device=local_device)
+ run_ae_test(model, quiz_machine, n_epoch, local_device=local_device)
######################################################################