######################################################################
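+# Sample nb new quizzes with the generator, conditioned on a prolog that
+# specifies which GPTs should succeed or fail on them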
+def generate_c_quizz_with_generator(generator, quiz_machine, nb):
+    generator.to(main_device)
+
+    c_quizzes = quiz_machine.problem.create_empty_quizzes(
+        nb, struct=("A", "f_A", "B", "f_B")
+    )
+
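+    # Pick one GPT at random per quiz; it gets token_prolog_0 ("not
+    # understood") while the others get token_prolog_2 ("understood")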
+    i = F.one_hot(
+        torch.randint(args.nb_gpts, (c_quizzes.size(0),)),
+        num_classes=args.nb_gpts,
+    )
+
+    prolog = token_prolog_0 * i + token_prolog_2 * (1 - i)
+    len_prolog, len_quiz = prolog.size(1), c_quizzes.size(1)
+
+    prologued_c_quizzes = torch.cat([prolog, c_quizzes], dim=1).to(main_device)
+
+    T = torch.arange(prologued_c_quizzes.size(1), device=prologued_c_quizzes.device)[
+        None, :
+    ]
+
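+    # Autoregress on everything after the prolog, except the first token
+    # of each of the four quiz parts (A, f_A, B, f_B), which is left as
+    # set by create_empty_quizzes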
+    ar_mask = ((T >= len_prolog) & ((T - len_prolog) % (len_quiz // 4) > 0)).long()
+
+    seq_logproba = torch.zeros(
+        prologued_c_quizzes.size(0), device=prologued_c_quizzes.device
+    )
+
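+    # Sample the masked tokens with the generator temporarily in eval mode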
+    with torch.autograd.no_grad():
+        t = generator.training
+        generator.eval()
+
+        one_batch_masked_inplace_autoregression(
+            generator,
+            prologued_c_quizzes,
+            ar_mask,
+            seq_logproba,
+            deterministic_synthesis=False,
+        )
+
+        generator.train(t)
+
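+    # Zero out any prolog token (>= vocabulary_size) that may have been
+    # sampled, so the returned sequences contain only quiz tokens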
+    prologued_c_quizzes = (
+        prologued_c_quizzes * (prologued_c_quizzes < vocabulary_size).long()
+    )
+
+    return prologued_c_quizzes[:, len_prolog:].to("cpu")
+
+
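+# Yield training batches for the generator: quizzes (world quizzes if
+# w_quizzes is True, otherwise quizzes sampled from the generator itself)
+# prefixed with a prolog that encodes how well each GPT solves them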
+def batches_for_generator(generator, quiz_machine, models, w_quizzes=True):
+    samples = []
+
+    for _ in range(args.nb_train_samples // args.batch_size):
+        while sum([x.size(0) for x in samples]) < args.batch_size:
+            # Generate a bunch of quizzes
+
+            if w_quizzes:
+                # Either we start with the world quizzes
+                c_quizzes = quiz_machine.problem.generate_w_quizzes(
+                    args.batch_size, progress_bar=False
+                )
+            else:
+                # Or we use the generator itself to generate them
+                c_quizzes = generate_c_quizz_with_generator(
+                    generator, quiz_machine, args.batch_size
+                )
+
+            # We remove the trivial ones
+            to_keep = quiz_machine.problem.trivial(c_quizzes) == False
+            c_quizzes = c_quizzes[to_keep]
+
+            # If there are remaining ones, we compute the true prolog
+            # that indicates how the GPTs solve it
+
+            if c_quizzes.size(0) > 0:
+                seq_logproba = quiz_machine.models_logprobas(
+                    models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+                ) + quiz_machine.models_logprobas(
+                    models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
+                )
+
+                probas = seq_logproba.exp()
+
+                nu = probas <= args.proba_not_understands
+                u = probas >= args.proba_understands
+
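+                # One prolog token per GPT: token_prolog_0 if it does not
+                # understand the quiz, token_prolog_2 if it does, and
+                # token_prolog_1 if undecided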
+                prolog = (
+                    (nu.long() * token_prolog_0)
+                    + (((nu == False) & (u == False)).long() * token_prolog_1)
+                    + (u.long() * token_prolog_2)
+                )
+
+                prologued_c_quizzes = torch.cat([prolog, c_quizzes], dim=1)
+
+                # nb_u = u.long().sum(dim=1)
+                # nb_nu = nu.long().sum(dim=1)
+
+                # prologued_c_quizzes = prologued_c_quizzes[
+                #     (nb_u + nb_nu == args.nb_gpts)
+                #     & (nb_nu >= 1)
+                #     & (nb_nu <= args.max_fail_to_validate)
+                # ]
+
+                samples.append(prologued_c_quizzes)
+
+        # Now we yield a batch
+
+        x = torch.cat(samples, dim=0)
+        samples = [x[args.batch_size :]]
+
+        yield x[: args.batch_size]
+
+
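+# One epoch of next-token prediction training of the generator on the
+# prologued quizzes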
+def one_generator_epoch(
+    generator, quiz_machine, models, w_quizzes=True, local_device=main_device
+):
+    generator.to(local_device).train()
+
+    optimizer = torch.optim.Adam(generator.parameters(), lr=args.learning_rate)
+
+    nb_train_samples, acc_train_loss = 0, 0.0
+
+    hard_w_quizzes = []
+
+    src = batches_for_generator(
+        generator=generator,
+        quiz_machine=quiz_machine,
+        models=models,
+        w_quizzes=w_quizzes,
+    )
+
+    for input in tqdm.tqdm(
+        src,
+        dynamic_ncols=True,
+        desc="training",
+        total=args.nb_train_samples // args.batch_size,
+    ):
+        input = input.to(local_device)
+
+        if nb_train_samples % args.batch_size == 0:
+            optimizer.zero_grad()
+
+        targets = input
+
+        output = generator(mygpt.BracketedSequence(input)).x
+        loss = F.cross_entropy(output.transpose(1, 2), targets)
+        acc_train_loss += loss.item() * input.size(0)
+        nb_train_samples += input.size(0)
+
+        loss.backward()
+
+        if nb_train_samples % args.batch_size == 0:
+            optimizer.step()
+
+    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+    log_string(f"train_perplexity {n_epoch} generator - {train_perplexity}")
+
+    generator.to(main_device)
+
+
+######################################################################
+
+
def train_complexifier(model_gen, model_pred1, model_pred2):
    samples = []
    perf = []
######################################################################
-token_prolog_0 = vocabulary_size + 0
-token_prolog_1 = vocabulary_size + 1
-token_prolog_2 = vocabulary_size + 2
-generator_vocabulary_size = vocabulary_size + 3
-
-generator = mygpt.MyGPT(
-    vocabulary_size=generator_vocabulary_size,
-    dim_model=args.dim_model,
-    dim_keys=args.dim_keys,
-    dim_hidden=args.dim_hidden,
-    nb_heads=args.nb_heads,
-    nb_blocks=args.nb_blocks,
-    causal=True,
-    dropout=args.dropout,
-).to(main_device)
-
-generator.main_test_accuracy = 0.0
-
-
-######################################################################
-
-
-def generate_c_quizz_with_generator(generator, quiz_machine):
-    c_quizzes = quiz_machine.problem.create_empty_quizzes(
-        args.batch_size, struct=("A", "f_A", "B", "f_B")
-    )
-    i = F.one_hot(
-        torch.randint(args.nb_gpts, (c_quizzes.size(0),)),
-        num_classes=args.nb_gpts,
-    )
-    prolog = token_prolog_0 * i + token_prolog_2 * (1 - i)
-    c_quizzes = torch.cat([prolog, c_quizzes], dim=1)
-    ar_mask = (
-        torch.arange(c_quizzes.size(1), device=c_quizzes.device)[None, :]
-        >= args.nb_gpts
-    ).long()
-
-    one_batch_masked_inplace_autoregression(
-        generator,
-        c_quizzes,
-        ar_mask,
-        seq_logproba,
-        deterministic_synthesis=False,
-    )
-
-    return c_quizzes[:, args.nb_gpts :]
-
-
-def batches_for_generator(generator=None, quiz_machine=None, device=main_device):
-    samples = []
-
-    for _ in range(args.nb_train_samples // args.batch_size):
-        while sum([x.size(0) for x in samples]) < args.batch_size:
-            # Generate a bunch of quizzes
-
-            if generator is None:
-                # Either we start with the world quizzes
-                c_quizzes = quiz_machine.problem.generate_w_quizzes(args.batch_size)
-            else:
-                # Or we use the generator itself to generate them
-                c_quizzes = generate_c_quizz_with_generator(generator, quiz_machine)
-
-            # We remove the trivial ones
-            to_keep = quiz_machine.problem.trivial(c_quizzes) == False
-            c_quizzes = c_quizzes[to_keep]
-
-            # If there are remaining ones, we compute the true prolog
-            # that indicates how the GPTs solve it
-
-            if c_quizzes.size(0) > 0:
-                seq_logproba = quiz_machine.models_logprobas(
-                    models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
-                ) + quiz_machine.models_logprobas(
-                    models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
-                )
-
-                probas = seq_logproba.exp()
-
-                nu = probas <= args.proba_not_understands
-                u = probas >= args.proba_understands
-
-                prolog = (
-                    (nu.long() * token_prolog_0)
-                    + (u.long() * token_prolog_2)
-                    + ((nu == False & u == False).long() * token_prolog_1)
-                )
-
-                samples.append(torch.cat([prolog, c_quizzes], dim=1))
-
-        # Now we yield a batch
-
-        x = torch.cat(samples, dim=0)
-        samples = [x[args.batch_size :]]
-
-        yield x[: args.batch_size]
-
-
-def one_generator_epoch(
-    generator, quiz_machine=None, models=None, local_device=main_device
-):
-    model.to(local_device).train()
-
-    optimizer = torch.optim.Adam(generator.parameters(), lr=args.learning_rate)
-
-    nb_train_samples, acc_train_loss = 0, 0.0
-
-    hard_w_quizzes = []
-
-    full_input, full_from_w = quiz_machine.data_input(generator, split="train")
-    src = zip(full_input.split(args.batch_size), full_from_w.split(args.batch_size))
-
-    for input, from_w in tqdm.tqdm(
-        src,
-        dynamic_ncols=True,
-        desc="training",
-        total=full_input.size(0) // args.batch_size,
-    ):
-        input = input.to(local_device)
-
-        if nb_train_samples % args.batch_size == 0:
-            optimizer.zero_grad()
-
-        targets = input
-
-        output = generator(mygpt.BracketedSequence(input)).x
-        loss_per_token = F.cross_entropy(
-            output.transpose(1, 2), targets, reduction="none"
-        )
-        loss = loss_per_token.mean()
-        acc_train_loss += loss.item() * input.size(0)
-
-        loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
-        if from_w.any():
-            hard_w_quizzes.append(
-                (input[from_w].to("cpu"), loss_per_samples[from_w].to("cpu"))
-            )
-
-        nb_train_samples += input.size(0)
-
-        loss.backward()
-
-        if nb_train_samples % args.batch_size == 0:
-            optimizer.step()
-
-    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-
-    log_string(
-        f"train_perplexity {n_epoch} generator {generator.id} {train_perplexity}"
-    )
-
-    run_tests(generator, quiz_machine)
-
-    threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values
-    threshold = threshold[threshold.size(0) // 2]
-
-    generator.hard_w_quizzes = torch.cat(
-        [x[l >= threshold] for x, l in hard_w_quizzes], dim=0
-    )
-
-    generator.to(main_device)
-
-
-######################################################################
-
current_epoch = 0
if args.resume:
# exit(0)
+######################################################################
+
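+# The generator's vocabulary extends the quiz vocabulary with three prolog
+# tokens that encode, for each GPT, whether it fails on the quiz, is
+# undecided, or succeeds on it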
+token_prolog_0 = vocabulary_size + 0
+token_prolog_1 = vocabulary_size + 1
+token_prolog_2 = vocabulary_size + 2
+generator_vocabulary_size = vocabulary_size + 3
+
+generator = mygpt.MyGPT(
+    vocabulary_size=generator_vocabulary_size,
+    dim_model=args.dim_model,
+    dim_keys=args.dim_keys,
+    dim_hidden=args.dim_hidden,
+    nb_heads=args.nb_heads,
+    nb_blocks=args.nb_blocks,
+    causal=True,
+    dropout=args.dropout,
+).to(main_device)
+
+generator.main_test_accuracy = 0.0
+
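+# Bootstrap the generator on world quizzes, and print the probabilities
+# that the GPTs assign to the quizzes it generates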
+for n_epoch in range(25):
+    one_generator_epoch(
+        generator,
+        quiz_machine=quiz_machine,
+        models=models,
+        w_quizzes=True,
+        local_device=main_device,
+    )
+
+    c_quizzes = generate_c_quizz_with_generator(
+        generator, quiz_machine, args.batch_size
+    )
+
+    seq_logproba = quiz_machine.models_logprobas(
+        models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+    ) + quiz_machine.models_logprobas(
+        models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
+    )
+
+    print(seq_logproba.exp())
+
+
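+# Then train the generator on quizzes it generates itself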
+one_generator_epoch(
+    generator,
+    quiz_machine=quiz_machine,
+    models=models,
+    w_quizzes=False,
+    local_device=main_device,
+)
+
+exit(0)
+
+
######################################################################
for n_epoch in range(current_epoch, args.nb_epochs):