import mygpt
import sky, grids, quiz_machine
+from quiz_machine import one_batch_masked_inplace_autoregression
+
import threading, subprocess
import torch.multiprocessing as mp
######################################################################
-def train_autoencoder():
+models = []
+
+for k in range(args.nb_gpts):
+ log_string(f"creating model {k} and its w_quizzes")
+
    model = mygpt.MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=args.dim_model,
+        dim_keys=args.dim_keys,
        dim_hidden=args.dim_hidden,
        nb_heads=args.nb_heads,
        nb_blocks=args.nb_blocks,
-        causal=False,
+        causal=True,
        dropout=args.dropout,
-        autoencoder_dim=args.autoencoder_dim,
    ).to(main_device)
-    test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
+    model.main_test_accuracy = 0.0
+    model.id = k
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes(
+        args.nb_train_samples
+    )
-    nb_train_samples, acc_train_loss = 0, 0.0
+    model.test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
-    for n_epoch in range(args.nb_epochs):
-        train_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
-        for input in tqdm.tqdm(
-            train_w_quizzes.split(args.batch_size),
-            dynamic_ncols=True,
-            desc="training AE",
-            total=train_w_quizzes.size(0) // args.batch_size,
-        ):
-            model.train()
-            l = input.size(1) // 4
-            input = input[:, -l:].to(main_device)
+    models.append(model)
-            if nb_train_samples % args.batch_size == 0:
-                optimizer.zero_grad()
+######################################################################
-            z_shape = model.encode(mygpt.BracketedSequence(input.to(main_device)))
-            output = model.decode(z_shape).x
-            loss = F.cross_entropy(output.transpose(1, 2), input)
-            acc_train_loss += loss.item() * input.size(0)
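+# Three extra tokens, one per "level of understanding": the prolog
+# prepended to each quiz fed to the generator has one slot per GPT,
+# set to token_prolog_0 if that GPT fails the quiz, token_prolog_1 if
+# it is uncertain, and token_prolog_2 if it succeeds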
+token_prolog_0 = vocabulary_size + 0
+token_prolog_1 = vocabulary_size + 1
+token_prolog_2 = vocabulary_size + 2
+generator_vocabulary_size = vocabulary_size + 3
-            nb_train_samples += input.size(0)
+generator = mygpt.MyGPT(
+    vocabulary_size=generator_vocabulary_size,
+    dim_model=args.dim_model,
+    dim_keys=args.dim_keys,
+    dim_hidden=args.dim_hidden,
+    nb_heads=args.nb_heads,
+    nb_blocks=args.nb_blocks,
+    causal=True,
+    dropout=args.dropout,
+).to(main_device)
-            loss.backward()
+generator.main_test_accuracy = 0.0
+generator.id = 0  # referenced by the training log in one_generator_epoch
-            if nb_train_samples % args.batch_size == 0:
-                optimizer.step()
-        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+######################################################################
- log_string(f"train_perplexity {n_epoch} model ae {train_perplexity}")
- filename = f"autoencoder.pth"
- torch.save(
- model.state_dict(),
- os.path.join(args.result_dir, filename),
- )
- log_string(f"wrote {filename}")
+def generate_c_quizz_with_generator(generator, quiz_machine):
+    c_quizzes = quiz_machine.problem.create_empty_quizzes(
+        args.batch_size, struct=("A", "f_A", "B", "f_B")
+    )
+
+    # Pick one GPT per quiz at random, and build a prolog that marks
+    # it as failing (token_prolog_0) and all the others as succeeding
+    # (token_prolog_2)
+    i = F.one_hot(
+        torch.randint(args.nb_gpts, (c_quizzes.size(0),)),
+        num_classes=args.nb_gpts,
+    )
+    prolog = token_prolog_0 * i + token_prolog_2 * (1 - i)
+    # Move to the generator's device for the autoregression
+    c_quizzes = torch.cat([prolog, c_quizzes], dim=1).to(main_device)
+
+    # Generate everything but the prolog
+    ar_mask = (
+        torch.arange(c_quizzes.size(1), device=c_quizzes.device)[None, :]
+        >= args.nb_gpts
+    ).long()
+
+    seq_logproba = torch.zeros(c_quizzes.size(0), device=c_quizzes.device)
+
+    one_batch_masked_inplace_autoregression(
+        generator,
+        c_quizzes,
+        ar_mask,
+        seq_logproba,
+        deterministic_synthesis=False,
+    )
-        with torch.autograd.no_grad():
-            model.eval()
-            input = test_w_quizzes[0 * 128 : 1 * 128, -l:]
-            z_shape = model.encode(mygpt.BracketedSequence(input.to(main_device)))
-            logits = model.decode(z_shape).x
+
+    # Strip the prolog, and return the quizzes on the CPU like the
+    # w_quizzes
+    return c_quizzes[:, args.nb_gpts :].to("cpu")
-            # dist = torch.distributions.categorical.Categorical(logits=logits)
-            # q = dist.sample()
-            q = logits.argmax(dim=-1)
-            q = q.reshape(q.size(0) // 2, 2, -1)
-            input = input.reshape(input.size(0) // 2, 2, -1)
-            q = torch.cat([input.to("cpu"), q.to("cpu")], dim=1).reshape(q.size(0), -1)
-            quiz_machine.problem.save_quizzes_as_image(
-                args.result_dir,
-                f"culture_ae_{n_epoch:04d}.png",
-                q,
-            )
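+# Minimal usage sketch (assumes the `generator` and `quiz_machine`
+# globals defined above; the prolog is stripped before returning):
+#
+#   c_quizzes = generate_c_quizz_with_generator(generator, quiz_machine)
+#   assert c_quizzes.size(0) == args.batch_size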
+def batches_for_generator(generator=None, quiz_machine=None, device=main_device):
+    samples = []
-            input1 = test_w_quizzes[1 * 128 : 2 * 128, -l:]
-            input2 = test_w_quizzes[2 * 128 : 3 * 128, -l:]
-            z_shape1 = model.encode(mygpt.BracketedSequence(input1.to(main_device)))
-            z_shape2 = model.encode(mygpt.BracketedSequence(input2.to(main_device)))
-            z_shape = ((z_shape1[0] + z_shape2[0]) * 0.5, z_shape1[1])
-            logits = model.decode(z_shape).x
+
+    for _ in range(args.nb_train_samples // args.batch_size):
+        while sum([x.size(0) for x in samples]) < args.batch_size:
+            # Generate a bunch of quizzes
-            q = logits.argmax(dim=-1)
-            # q = q.reshape(q.size(0) // 2, 2, -1)
-            # input = input.reshape(input.size(0) // 2, 2, -1)
-            # q = torch.cat([input.to("cpu"), q.to("cpu")], dim=1).reshape(q.size(0), -1)
+            if generator is None:
+                # Either we start with the world quizzes
+                c_quizzes = quiz_machine.problem.generate_w_quizzes(args.batch_size)
+            else:
+                # Or we use the generator itself to produce them
+                c_quizzes = generate_c_quizz_with_generator(generator, quiz_machine)
-            q = q.reshape(q.size(0) // 4, -1)
+            # We remove the trivial ones
+            to_keep = quiz_machine.problem.trivial(c_quizzes) == False
+            c_quizzes = c_quizzes[to_keep]
-            quiz_machine.problem.save_quizzes_as_image(
-                args.result_dir,
-                f"culture_mix_ae_{n_epoch:04d}.png",
-                q,
-            )
+            # If some remain, we compute the true prolog, which encodes
+            # how well each GPT solves them
-    return model
+            if c_quizzes.size(0) > 0:
+                seq_logproba = quiz_machine.models_logprobas(
+                    models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+                ) + quiz_machine.models_logprobas(
+                    models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
+                )
+                probas = seq_logproba.exp()
-# if args.autoencoder_dim > 0:
-#     ae = train_autoencoder()
-#     exit(0)
+                nu = probas <= args.proba_not_understands
+                u = probas >= args.proba_understands
-######################################################################
+                prolog = (
+                    (nu.long() * token_prolog_0)
+                    + (u.long() * token_prolog_2)
+                    + (((nu == False) & (u == False)).long() * token_prolog_1)
+                )
+                samples.append(torch.cat([prolog, c_quizzes], dim=1))
-models = []
+        # Now we yield a batch and keep the leftover samples
-for k in range(args.nb_gpts):
-    log_string(f"creating model {k} and its w_quizzes")
+        x = torch.cat(samples, dim=0)
+        samples = [x[args.batch_size :]]
-    model = mygpt.MyGPT(
-        vocabulary_size=vocabulary_size,
-        dim_model=args.dim_model,
-        dim_keys=args.dim_keys,
-        dim_hidden=args.dim_hidden,
-        nb_heads=args.nb_heads,
-        nb_blocks=args.nb_blocks,
-        causal=True,
-        dropout=args.dropout,
-    ).to(main_device)
+
+        yield x[: args.batch_size]
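+# Usage sketch: batches_for_generator is a generator function, so it
+# can be consumed lazily by a training loop. With generator=None it
+# yields prolog-prefixed w_quizzes, e.g.
+#
+#   for x in batches_for_generator(quiz_machine=quiz_machine):
+#       ...  # x is (args.batch_size, args.nb_gpts + quiz length), long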
-    model.main_test_accuracy = 0.0
-    model.id = k
-    model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes(
-        args.nb_train_samples
+def one_generator_epoch(
+    generator, quiz_machine=None, models=None, local_device=main_device
+):
+    generator.to(local_device).train()
+
+    optimizer = torch.optim.Adam(generator.parameters(), lr=args.learning_rate)
+
+    nb_train_samples, acc_train_loss = 0, 0.0
+
+    hard_w_quizzes = []
+
+    full_input, full_from_w = quiz_machine.data_input(generator, split="train")
+    src = zip(full_input.split(args.batch_size), full_from_w.split(args.batch_size))
+
+    for input, from_w in tqdm.tqdm(
+        src,
+        dynamic_ncols=True,
+        desc="training",
+        total=full_input.size(0) // args.batch_size,
+    ):
+        input = input.to(local_device)
+
+        if nb_train_samples % args.batch_size == 0:
+            optimizer.zero_grad()
+
+        targets = input
+
+        output = generator(mygpt.BracketedSequence(input)).x
+        loss_per_token = F.cross_entropy(
+            output.transpose(1, 2), targets, reduction="none"
+        )
+        loss = loss_per_token.mean()
+        acc_train_loss += loss.item() * input.size(0)
+
+        # Record the per-sample loss of the w_quizzes to spot the hard
+        # ones
+        loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
+        if from_w.any():
+            hard_w_quizzes.append(
+                (input[from_w].to("cpu"), loss_per_samples[from_w].to("cpu"))
+            )
+
+        nb_train_samples += input.size(0)
+
+        loss.backward()
+
+        if nb_train_samples % args.batch_size == 0:
+            optimizer.step()
+
+    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+    log_string(
+        f"train_perplexity {n_epoch} generator {generator.id} {train_perplexity}"
    )
-    model.test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
+    run_tests(generator, quiz_machine)
+
+    # Keep the w_quizzes whose loss is above the median, so that the
+    # hard ones can be re-used for training
+    if len(hard_w_quizzes) > 0:
+        threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values
+        threshold = threshold[threshold.size(0) // 2]
+
+        generator.hard_w_quizzes = torch.cat(
+            [x[l >= threshold] for x, l in hard_w_quizzes], dim=0
+        )
+
+    generator.to(main_device)
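+# Typical call, mirroring the per-GPT epoch functions elsewhere in
+# this file (assumes the global n_epoch maintained by the main loop):
+#
+#   one_generator_epoch(
+#       generator, quiz_machine=quiz_machine, models=models,
+#       local_device=main_device,
+#   )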
-    models.append(model)
######################################################################