parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
-parser.add_argument("--c_quiz_multiplier", type=int, default=1)
-
parser.add_argument("--learning_rate", type=float, default=5e-4)
parser.add_argument("--lambda_H", type=float, default=0.0)
nb_samples_accumulated = 0
full_input, full_mask_loss = quiz_machine.data_input(
- args.nb_test_samples, model.test_c_quiz_bags, args.c_quiz_multiplier
+ args.nb_test_samples, model.test_c_quiz_bags
)
src = zip(
full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
- input, _ = quiz_machine.data_input(
- 2000, model.test_c_quiz_bags, args.c_quiz_multiplier
- )
+ input, _ = quiz_machine.data_input(2000, model.test_c_quiz_bags)
model.test_accuracy = quiz_machine.produce_results(
n_epoch=n_epoch,
nb_train_samples, acc_train_loss = 0, 0.0
full_input, full_mask_loss = quiz_machine.data_input(
- args.nb_train_samples, model.train_c_quiz_bags, args.c_quiz_multiplier
+ args.nb_train_samples, model.train_c_quiz_bags
)
src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
return l.exp()
-def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
+def create_c_quizzes(main_model, other_models, quiz_machine, nb_for_train, nb_for_test):
nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models)
nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate
start_time = time.perf_counter()
- for model in models:
- model.recorded_c_quizzes = []
-
- teaching_count = torch.zeros(len(models), len(models), dtype=torch.int64)
+ recorded = []
while nb_validated < nb_to_validate:
- model_for_generation = models[torch.randint(len(models), (1,)).item()]
-
# We generate quizzes with a procedure that injects some
# structured noise
c_quizzes = quiz_machine.generate_c_quizzes(
nb_to_generate_per_iteration,
- model_for_generation=model,
+ model_for_generation=main_model,
procedure=c_quizzes_procedure,
)
c_quizzes = c_quizzes[to_keep]
- # Compute the responses of all the models on the c_quizzes,
- # and their proba estimates of their responses
+ # Keep only the quizzes that the main model cannot solve
- solved_c_quizzes = c_quizzes[:, None, :].expand(-1, len(models), -1).clone()
+ solved_c_quizzes = c_quizzes.clone()
- proba_own_solution = torch.zeros(
- c_quizzes.size(0), len(models), device=solved_c_quizzes.device
+ main_solution, _, _ = quiz_machine.predict(
+ main_model,
+ solved_c_quizzes,
+ struct=("A", "f_A", "B", "f_B"),
+ mask=(0, 0, 0, 1),
)
- for model in models:
- (solved_c_quizzes[:, model.id], _, _) = quiz_machine.predict(
- model,
- solved_c_quizzes[:, model.id],
- struct=("A", "f_A", "B", "f_B"),
- mask=(0, 0, 0, 1),
- )
+ keep = (
+ model_proba_solutions(main_model, main_solution)
+ < args.proba_not_understands
+ )
+ c_quizzes = c_quizzes[keep]
+
+ # If there are some quizzes that the main model cannot solve,
+ # pick the most confident solution
- proba_own_solution[:, model.id] = model_proba_solutions(
- model, solved_c_quizzes[:, model.id]
+ if c_quizzes.size(0) > 0:
+ solution = c_quizzes.clone()
+ c_quizzes_proba = torch.zeros(
+ solution.size(0), dtype=torch.float32, device=solution.device
)
- # Now for every model not confident of its response, we pick
- # the most consistent from a model which is confident
-
- for s in range(proba_own_solution.size(0)):
- # At least one GPT does not understand at all
- if proba_own_solution[s, :].min() < args.proba_not_understands:
- dont_get_this_quiz = proba_own_solution[s, :] < args.proba_understands
- nb_fails = dont_get_this_quiz.long().sum()
- # At most max_fail_to_validate do not understand (default 3/5)
- if nb_fails >= 1 and nb_fails <= args.max_fail_to_validate:
- for model in models:
- # If a GPT does not get that quiz
- if dont_get_this_quiz[model.id]:
- assert (
- proba_own_solution[s, model.id] < args.proba_understands
- )
- # Look at its estimate of the others'solutions
- proba_other_solutions = model_proba_solutions(
- model, solved_c_quizzes[s]
- )
- # Randomize a bit the orders for the frequent P=1
- proba_other_solutions += (
- torch.rand(proba_other_solutions.size()) * 1e-6
- )
- # Remove the under threshold confidence solutions
- proba_other_solutions[dont_get_this_quiz] = -1
- i = proba_other_solutions.argmax()
- model.recorded_c_quizzes.append(solved_c_quizzes[s, i])
- teaching_count[i, model.id] += 1
- nb_validated += 1
+ for model in other_models:
+ solution, _, _ = quiz_machine.predict(
+ model,
+ solution,
+ struct=("A", "f_A", "B", "f_B"),
+ mask=(0, 0, 0, 1),
+ )
+
+ probas = model_proba_solutions(model, solution)
+ keep = probas >= c_quizzes_proba
+ c_quizzes = solution[keep]
+ c_quizzes_proba[keep] = probas[keep]
+
+ keep = c_quizzes_proba >= args.proba_understands
+ recorded.append(c_quizzes_proba[keep])
+ nb_validated += keep.long().sum()
duration = time.perf_counter() - start_time
f"keep c_quizzes model {model_for_generation.id} validated nb_validated {nb_validated} / {nb_to_validate} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h) proportion_kept {nb_validated * 100 / nb_generated:.02f}%"
)
- for s in range(teaching_count.size(0)):
- o = [x.item() for x in teaching_count[s]]
- log_string(f"teacher model {s} to {o}")
+ # Save some images
- for model in models:
- new_bag = torch.cat([q[None, :] for q in model.recorded_c_quizzes], dim=0)
-
- if new_bag.size(0) > 0:
- n = (new_bag.size(0) * nb_for_train) // (nb_for_train + nb_for_test)
- if n > 0:
- model.train_c_quiz_bags.append(new_bag[:n])
- if n < new_bag.size(0):
- model.test_c_quiz_bags.append(new_bag[n:])
-
- c_quizzes = new_bag[:128]
-
- l = [model_proba_solutions(model, c_quizzes) for model in models]
- probas = torch.cat([x[:, None] for x in l], dim=1)
- comments = []
-
- for l in probas:
- comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
-
- filename = f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}.png"
- quiz_machine.problem.save_quizzes_as_image(
- args.result_dir, filename, c_quizzes, comments=comments
- )
+ c_quizzes = torch.cat(recorded, dim=0)
- log_string(
- f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in model.train_c_quiz_bags ])} test {sum([q.size(0) for q in model.test_c_quiz_bags ])}"
- )
+ l = [
+ model_proba_solutions(model, c_quizzes) for model in [main_model] + other_models
+ ]
+ probas = torch.cat([x[:, None] for x in l], dim=1)
+ comments = []
+ for l in probas:
+ comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+ filename = f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}.png"
+ quiz_machine.problem.save_quizzes_as_image(
+ args.result_dir, filename, c_quizzes[:128], comments=comments
+ )
-######################################################################
-from mygpt import (
- WithResidual,
- CacheWrapper,
- AddPositionalEncoding,
- QKVAttention,
- BracketedSequence,
+log_string(
+ f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in model.train_c_quiz_bags ])} test {sum([q.size(0) for q in model.test_c_quiz_bags ])}"
)
-class Thinker(nn.Module):
- def __init__(
- self,
- vocabulary_size,
- dim_model,
- dim_keys,
- dim_hidden,
- nb_heads,
- nb_blocks,
- f_len,
- dropout=0.0,
- len_max=1e5,
- ):
- super().__init__()
-
- assert dim_model % nb_heads == 0
-
- self.embedding = nn.Sequential(
- CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
- AddPositionalEncoding(len_max),
- )
-
- def trunk(depth):
- trunk_blocks = []
-
- for b in range(nb_blocks):
- trunk_blocks += [
- WithResidual(
- CacheWrapper(
- nn.LayerNorm((dim_model,)),
- ),
- QKVAttention(
- dim_in=dim_model,
- dim_qk=dim_keys,
- dim_v=dim_model // nb_heads,
- nb_heads=nb_heads,
- attention_dropout=dropout,
- ),
- ),
- WithResidual(
- CacheWrapper(
- nn.LayerNorm((dim_model,)),
- nn.Linear(in_features=dim_model, out_features=dim_hidden),
- nn.ReLU(),
- nn.Linear(in_features=dim_hidden, out_features=dim_model),
- nn.Dropout(dropout),
- ),
- ),
- ]
-
- return nn.Sequential(*trunk_blocks)
-
- self.bottom_trunk = trunk(nb_blocks // 2)
-
- self.top_trunk = trunk(nb_blocks // 2)
-
- self.readout = CacheWrapper(
- nn.Linear(in_features=dim_model, out_features=vocabulary_size)
- )
-
- self.fun_embedding = nn.Parameter(torch.randn(1, f_len, dim_model))
-
- with torch.no_grad():
- for m in self.modules():
- if isinstance(m, nn.Embedding):
- m.weight.normal_(mean=0, std=2e-2)
- elif isinstance(m, nn.LayerNorm):
- m.bias.zero_()
- m.weight.fill_(1.0)
-
- def forward(self, bs):
- for m in self.modules():
- m.loss = 0
-
- L = bs.x.size(1) // 3
-
- bs = self.embedding(bs)
- A_fA = BracketedSequence(bs.x[:, : 2 * L])
- B = BracketedSequence(bs.x[:, -L:])
-
- bs = BracketedSequence(
- torch.cat([A_fA.x, self.fun_embedding.expand(bs.x.size(0), -1, -1)], dim=1)
- )
- bs = self.bottom_trunk(bs)
- bs = BracketedSequence(torch.cat([bs.x[:, -f_len:, :], B.x], dim=1))
- bs = self.top_trunk(bs)
- bs = BracketedSequence(bs.x[:, f_len:, :])
- bs = self.readout(bs)
-
- for m in self.modules():
- if m is not self:
- self.loss += m.loss
-
- return bs
-
-
######################################################################
models = []
model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
model.test_accuracy = 0.0
- model.best_test_accuracy = 0.0
- model.best_dict = copy.deepcopy(model.state_dict())
models.append(model)
######################################################################
current_epoch = 0
-# We balance the computing time between training the models and
-# generating c_quizzes
-
-total_time_generating_c_quizzes = 0
-total_time_training_models = 0
-
if args.resume:
for model in models:
filename = f"gpt_{model.id:03d}.pth"
model.load_state_dict(d["state_dict"])
model.optimizer.load_state_dict(d["optimizer_state_dict"])
model.test_accuracy = d["test_accuracy"]
- model.best_test_accuracy = d["best_test_accuracy"]
- model.best_dict = d["best_dict"]
model.train_c_quiz_bags = d["train_c_quiz_bags"]
model.test_c_quiz_bags = d["test_c_quiz_bags"]
log_string(f"successfully loaded {filename}")
state = torch.load(os.path.join(args.result_dir, filename))
log_string(f"successfully loaded {filename}")
current_epoch = state["current_epoch"]
- total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
- total_time_training_models = state["total_time_training_models"]
except FileNotFoundError:
log_string(f"cannot find {filename}")
pass
return input
-if args.test == "mlp":
- model = models[0]
- tape_input, tape_output = [], []
- L = len(model.trunk)
- model.trunk.insert(L // 2 + 1, Recorder(tape_output))
- model.trunk.insert(L // 2, Recorder(tape_input))
-
- mlp = nn.Sequential(
- nn.Linear(args.dim_model, args.dim_model),
- nn.ReLU(),
- nn.Linear(args.dim_model, args.dim_model),
- nn.ReLU(),
- nn.Linear(args.dim_model, 8 * args.dim_model),
- Folder(),
- Unfolder(404, 8 * args.dim_model),
- nn.Linear(8 * args.dim_model, args.dim_model),
- nn.ReLU(),
- nn.Linear(args.dim_model, args.dim_model),
- nn.ReLU(),
- nn.Linear(args.dim_model, args.dim_model),
- ).to(main_device)
-
- mlp.optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)
-
- for n_epoch in range(args.nb_epochs):
- train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
-
- tape_input.clear()
- tape_output.clear()
-
- with torch.autograd.no_grad():
- model.to(main_device).eval()
- for input in train_input.split(args.batch_size):
- input = input.to(main_device)
- output = model(mygpt.BracketedSequence(input)).x
-
- train_input = torch.cat([bs.x for bs in tape_input], dim=0)
- train_targets = torch.cat([bs.x for bs in tape_output], dim=0)
-
- nb_train_samples, acc_train_loss = 0, 0.0
- src = zip(
- train_input.split(args.batch_size), train_targets.split(args.batch_size)
- )
- for input, targets in tqdm.tqdm(
- src,
- dynamic_ncols=True,
- desc="train",
- total=train_input.size(0) // args.batch_size,
- ):
- input = input.to(main_device)
- output = mlp(input)
- loss = F.mse_loss(output, targets) + output.abs().sum()
- acc_train_loss += loss.item() * input.size(0)
- nb_train_samples += input.size(0)
-
- mlp.optimizer.zero_grad()
- loss.backward()
- mlp.optimizer.step()
-
- log_string(f"mlp_loss {n_epoch} train {acc_train_loss/nb_train_samples}")
-
- exit(0)
-
######################################################################
######################################################################
-
-if args.test == "entropy":
- model = models[0]
- model.to(main_device)
-
- optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
-
- log_string("starting testing entropy maximization")
-
- for n_epoch in range(100):
- input = quiz_machine.generate_c_quizzes(
- 128,
- model_for_generation=model,
- procedure=c_quizzes_procedure,
- )
-
- quiz_machine.problem.save_quizzes_as_image(
- args.result_dir,
- f"test_{n_epoch:04d}.png",
- quizzes=input,
- )
-
- log_string(f"wrote {filename}")
-
- with torch.no_grad():
- for p in model.parameters():
- p += torch.randn(p.size(), device=p.device) * 1e-3
-
- # nb_train_samples, acc_train_loss = 0, 0.0
-
- # for k in range(1000 // args.batch_size):
- # input = quiz_machine.generate_c_quizzes(
- # args.batch_size,
- # model_for_generation=model,
- # procedure=[(("f_B", "f_A", "A", "B"), (1, 1, 1, 1), None)],
- # )
-
- # input = input.to(main_device)
- # targets = input
- # output = model(mygpt.BracketedSequence(input)).x
- # loss = -F.cross_entropy(output.transpose(1, 2), targets)
- # acc_train_loss += loss.item() * input.size(0)
- # nb_train_samples += input.size(0)
-
- # optimizer.zero_grad()
- # loss.backward()
- # optimizer.step()
-
- # log_string(
- # f"increase_entropy {n_epoch} entropy {acc_train_loss/nb_train_samples}"
- # )
-
- exit(0)
-
-######################################################################
-
for n_epoch in range(current_epoch, args.nb_epochs):
state = {
"current_epoch": n_epoch,
- "total_time_training_models": total_time_training_models,
- "total_time_generating_c_quizzes": total_time_generating_c_quizzes,
}
filename = "state.pth"
torch.save(state, os.path.join(args.result_dir, filename))
cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models])
log_string(f"current_test_accuracies {cta}")
- cta = " ".join([f"{float(m.best_test_accuracy):.04f}" for m in models])
- log_string(f"current_best_test_accuracies {cta}")
-
##################################################
- for model in models:
- if model.test_accuracy >= args.accuracy_to_make_c_quizzes:
- log_string(
- f"storing_best model {model.id} accuracy {model.best_test_accuracy} -> {model.test_accuracy}"
- )
- model.best_dict = copy.deepcopy(model.state_dict())
- model.best_test_accuracy = model.test_accuracy
-
- # we restart
- if total_time_generating_c_quizzes == 0:
- total_time_training_models = 0
-
- if (
- min([m.best_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes
- and total_time_training_models >= total_time_generating_c_quizzes
- ):
- for model in models:
- model.current_dict = copy.deepcopy(model.state_dict())
- model.load_state_dict(model.best_dict)
-
- start_time = time.perf_counter()
+ if min([m.test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
record_new_c_quizzes(
models,
quiz_machine,
args.nb_new_c_quizzes_for_train,
args.nb_new_c_quizzes_for_test,
)
- total_time_generating_c_quizzes += time.perf_counter() - start_time
- # Force one epoch of training
for model in models:
- model.load_state_dict(model.current_dict)
+ new_model = mygpt.MyGPT(
+ vocabulary_size=vocabulary_size,
+ dim_model=args.dim_model,
+ dim_keys=args.dim_keys,
+ dim_hidden=args.dim_hidden,
+ nb_heads=args.nb_heads,
+ nb_blocks=args.nb_blocks,
+ compute_attzero=compute_causal_attzero,
+ dropout=args.dropout,
+ ).to(main_device)
+ model.load_state_dict(new_model.state_dict())
+ model.test_accuracy = 0.0
+ model.best_test_accuracy = 0.0
+ model.best_dict = copy.deepcopy(model.state_dict())
##################################################
# Select, improve, and eval the worst model(s)
- if total_time_training_models <= total_time_generating_c_quizzes:
- ranked_models = sorted(
- models,
- # This ugly recipe will pick the worst if there some below
- # args.accuracy_to_make_c_quizzes or one at random if they
- # are all above
- key=lambda m: float(
- m.test_accuracy
- if m.test_accuracy < args.accuracy_to_make_c_quizzes
- else args.accuracy_to_make_c_quizzes + torch.rand(1).item()
- ),
- )
+ ranked_models = sorted(
+ models,
+ # This ugly recipe will pick the worst if there some below
+ # args.accuracy_to_make_c_quizzes or one at random if they
+ # are all above
+ key=lambda m: float(
+ m.test_accuracy
+ if m.test_accuracy < args.accuracy_to_make_c_quizzes
+ else args.accuracy_to_make_c_quizzes + torch.rand(1).item()
+ ),
+ )
- weakest_models = ranked_models[: len(gpus)]
+ weakest_models = ranked_models[: len(gpus)]
- threads = []
+ threads = []
- start_time = time.perf_counter()
+ start_time = time.perf_counter()
- for gpu, model in zip(gpus, weakest_models):
- log_string(f"training model {model.id}")
+ for gpu, model in zip(gpus, weakest_models):
+ log_string(f"training model {model.id}")
- t = threading.Thread(
- target=one_epoch, daemon=True, args=(model, quiz_machine, gpu)
- )
+ t = threading.Thread(
+ target=one_epoch, daemon=True, args=(model, quiz_machine, gpu)
+ )
- threads.append(t)
+ threads.append(t)
- t.start()
+ t.start()
- for t in threads:
- t.join()
+ for t in threads:
+ t.join()
- total_time_training_models += time.perf_counter() - start_time
+ total_time_training_models += time.perf_counter() - start_time
- for model in weakest_models:
- save_additional_results(n_epoch, model, models, c_quizzes_procedure)
+ for model in weakest_models:
+ save_additional_results(n_epoch, model, models, c_quizzes_procedure)
# Save the models to disk