import mygpt
import sky, grids, quiz_machine
+import threading
+
# world quizzes vs. culture quizzes
######################################################################
######################################################################
parser = argparse.ArgumentParser(
- description="An implementation of GPT with cache.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--seed", type=int, default=0)
-parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
+parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1)
########################################
parser.add_argument("--problem", type=str, default="grids")
-parser.add_argument("--nb_threads", type=int, default=-1)
+parser.add_argument("--nb_threads", type=int, default=1)
+
+parser.add_argument("--nb_gpus", type=int, default=1)
parser.add_argument("--nb_gpts", type=int, default=5)
nb_birds=args.sky_nb_birds,
nb_iterations=args.sky_nb_iterations,
speed=args.sky_speed,
- max_nb_cached_chunks=args.nb_train_samples // 100,
+ max_nb_cached_chunks=args.nb_gpus * args.nb_train_samples // 100,
chunk_size=100,
nb_threads=args.nb_threads,
)
back_accuracy = False
elif args.problem == "grids":
problem = grids.Grids(
- device=device,
- max_nb_cached_chunks=args.nb_train_samples // 100,
+ max_nb_cached_chunks=args.nb_gpus * args.nb_train_samples // 100,
chunk_size=100,
nb_threads=args.nb_threads,
)
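+ # With one model training per GPU, the quiz cache has to feed args.nb_gpus
+ # consumers at once, hence the extra factor of args.nb_gpus above. As a
+ # rough illustration (made-up figures): nb_train_samples=250_000 and
+ # nb_gpus=2 give 2 * 250_000 // 100 = 5_000 cached chunks of 100 quizzes.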
######################################################################
-# Compute the entropy of the training tokens
-
-token_count = 0
-for input in quiz_machine.batches(split="train", desc="train-entropy"):
- token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum(
- (0, 1)
- )
-token_probas = token_count / token_count.sum()
-entropy = -torch.xlogy(token_probas, token_probas).sum()
-train_set_perplexity = math.exp(entropy)
######################################################################
-# A bit of paranoia never hurts
-if args.max_percents_of_test_in_train >= 0:
- def subsets_as_tuples(batches, cs):
- s = set()
- for batch in batches:
- for x in batch:
- s.add(tuple([v.item() for v in x]))
- if len(s) == cs:
- yield s
- s = set()
- yield s
+def run_tests(model, quiz_machine, deterministic_synthesis, local_device=None):
+ if local_device is None:
+ local_device = device
- nb_test, nb_in_train = 0, 0
- for test_subset in subsets_as_tuples(
- quiz_machine.batches(split="test", desc="test-check"), 25000
- ):
- in_train = set()
- for train_subset in subsets_as_tuples(
- quiz_machine.batches(split="train", desc="train-check"), 25000
- ):
- in_train.update(test_subset.intersection(train_subset))
- nb_in_train += len(in_train)
- nb_test += len(test_subset)
+ with torch.autograd.no_grad():
+ model.eval().to(local_device)
- log_string(
- f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
- )
+ nb_test_samples, acc_test_loss = 0, 0.0
+ nb_samples_accumulated = 0
- assert (
- nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
- ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
+ for input in quiz_machine.batches(model, split="test"):
+ input = input.to(local_device)
+
+ bs = model(mygpt.BracketedSequence(input))
+ output = bs.x
+
+ loss = F.cross_entropy(output.transpose(1, 2), input)
+
+ acc_test_loss += loss.item() * input.size(0)
+
+ nb_test_samples += input.size(0)
+
+ test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+
+ log_string(f"test_perplexity {n_epoch} {test_perplexity}")
+
+ model.main_test_accuracy = quiz_machine.produce_results(
+ n_epoch=n_epoch,
+ model=model,
+ result_dir=args.result_dir,
+ deterministic_synthesis=deterministic_synthesis,
+ )
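+ # test_perplexity is exp of the mean per-token cross-entropy over the test
+ # split; the exponent is capped at 100 so that math.exp cannot overflow on
+ # a badly trained model. For instance (made-up value), a mean loss of
+ # 2.3 nats per token gives a perplexity of about exp(2.3) ~= 10.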
-##############################
+def one_epoch(model, quiz_machine, local_device=None):
+ if local_device is None:
+ local_device = device
-def one_epoch(model, quiz_machine):
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
- model.train()
+ model.to(local_device).train()
nb_train_samples, acc_train_loss = 0, 0.0
- for input in quiz_machine.batches(split="train"):
- input = input.to(device)
+ for input in quiz_machine.batches(model, split="train"):
+ input = input.to(local_device)
if nb_train_samples % args.batch_size == 0:
optimizer.zero_grad()
log_string(f"train_perplexity {n_epoch} {train_perplexity}")
+ run_tests(model, quiz_machine, deterministic_synthesis=False, local_device=local_device)
-######################################################################
-
-
-def run_tests(model, quiz_machine, deterministic_synthesis):
- with torch.autograd.no_grad():
- model.eval()
-
- nb_test_samples, acc_test_loss = 0, 0.0
- nb_samples_accumulated = 0
-
- for input in quiz_machine.batches(split="test"):
- input = input.to(device)
-
- bs = model(mygpt.BracketedSequence(input))
- output = bs.x
-
- loss = F.cross_entropy(output.transpose(1, 2), input)
-
- acc_test_loss += loss.item() * input.size(0)
-
- nb_test_samples += input.size(0)
-
- test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
-
- log_string(f"test_perplexity {n_epoch} {test_perplexity}")
-
- model.main_test_accuracy = quiz_machine.produce_results(
- n_epoch=n_epoch,
- model=model,
- result_dir=args.result_dir,
- deterministic_synthesis=deterministic_synthesis,
- )
+ model.TRAINING_LOCK.release()
######################################################################
def standard_validity(logproba):
l = logproba.sort(dim=-1).values
- return logical_and(l[0] < math.log(0.5), l[1] > math.log(0.95))
+ return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.99))
+ # warnings.warn("TEST!!!", RuntimeWarning)
+ # print(l.exp())
+ # return (l[:, 0] < math.log(0.99))
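+ # Reading of standard_validity: logproba is expected to hold, for each
+ # candidate c_quiz, one log-probability of the correct solution per model.
+ # After sorting each row, l[:, 0] is the least confident model and l[:, 1]
+ # the second least confident, so a quiz is kept when exactly one model
+ # gives its solution probability below 0.5 while all the others give it
+ # probability above 0.99.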
def valid_c_quizzes(recorded, criteria):
c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
if c_quizzes.size(0) > 0:
- logproba = c_quizzes.new(c_quizzes.size(0), len(models))
- for q, l in zip(
- c_quizzes.split(args.batch_size), logits.split(args.batch_size)
- ):
- for model in models:
- l[model.id] = F.cross_entropy(model(q))
-
+ logproba = quiz_machine.logproba_of_solutions(models, c_quizzes)
for l in logproba:
s = " ".join([str(x.item()) for x in l])
logp_file.write(s + "\n")
-
quizzes_and_logproba_records.append((c_quizzes, logproba))
nb_validated = valid_c_quizzes(
# ------------------------------------------------------------
- standard_validity = lambda nb_correct: torch.logical_and(
- nb_correct >= args.min_to_validate, nb_correct <= args.max_to_validate
+ standard_validity = lambda nb_correct: (nb_correct >= args.min_to_validate) & (
+ nb_correct <= args.max_to_validate
)
file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat")
models = []
for k in range(args.nb_gpts):
+ log_string(f"creating model {k} and its w_quizzes")
model = mygpt.MyGPT(
vocabulary_size=vocabulary_size,
dim_model=args.dim_model,
model.main_test_accuracy = 0.0
model.id = k
+ model.TRAINING_LOCK = threading.Lock()
+
+ model.train_w_quizzes = quiz_machine.generate_token_sequences(
+ args.nb_train_samples
+ ).to(device)
+ quiz_machine.reverse_random_half_in_place(model.train_w_quizzes)
+ model.test_w_quizzes = quiz_machine.generate_token_sequences(
+ args.nb_test_samples
+ ).to(device)
+ quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
models.append(model)
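+ # Each model now owns a private set of train/test world quizzes;
+ # reverse_random_half_in_place, as its name suggests, reverses a random
+ # half of the sequences in place so that every model trains on both
+ # orderings of a quiz.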
######################################################################
+# Compute the entropy of the training tokens
+
+token_count = 0
+for input in quiz_machine.batches(models[0], split="train", desc="train-entropy"):
+ token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum(
+ (0, 1)
+ )
+token_probas = token_count / token_count.sum()
+entropy = -torch.xlogy(token_probas, token_probas).sum()
+train_set_perplexity = math.exp(entropy)
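+ # train_set_perplexity is exp(H), with H = -sum_t p_t * log(p_t) the entropy
+ # of the empirical token distribution. Sanity check (hypothetical numbers):
+ # a vocabulary of V tokens used uniformly gives H = log(V), hence a
+ # perplexity of exactly V.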
+
+######################################################################
+# A bit of paranoia never hurts
+
+if args.max_percents_of_test_in_train >= 0:
+
+ def subsets_as_tuples(batches, cs):
+ s = set()
+ for batch in batches:
+ for x in batch:
+ s.add(tuple([v.item() for v in x]))
+ if len(s) == cs:
+ yield s
+ s = set()
+ yield s
+
+ nb_test, nb_in_train = 0, 0
+ for test_subset in subsets_as_tuples(
+ quiz_machine.batches(models[0], split="test", desc="test-check"), 25000
+ ):
+ in_train = set()
+ for train_subset in subsets_as_tuples(
+ quiz_machine.batches(models[0], split="train", desc="train-check"), 25000
+ ):
+ in_train.update(test_subset.intersection(train_subset))
+ nb_in_train += len(in_train)
+ nb_test += len(test_subset)
+
+ log_string(
+ f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
+ )
+
+ assert (
+ nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
+ ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
+
+######################################################################
+
nb_new_c_quizzes_for_train = args.nb_train_samples // 50
nb_new_c_quizzes_for_test = args.nb_test_samples // 50
##################################################
-# Select, improve, and eval the worst model
+# Select, improve, and eval the weakest models, one per GPU
- weakest_model = min(models, key=lambda m: float(m.main_test_accuracy))
+ ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
- log_string(
- f"training model {weakest_model.id} main_test_accuracy {weakest_model.main_test_accuracy}"
- )
+ weakest_models = ranked_models[: args.nb_gpus]
- one_epoch(weakest_model, quiz_machine)
+ for gpu_id, model in enumerate(weakest_models):
+ model.TRAINING_LOCK.acquire()
- log_string(
- f"train_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}"
- )
+ log_string(
+ f"training model {model.id} main_test_accuracy {model.main_test_accuracy}"
+ )
- run_tests(weakest_model, quiz_machine, deterministic_synthesis=False)
+ threading.Thread(
+ target=one_epoch, daemon=True, args=(model, quiz_machine, f"cuda:{gpu_id}")
+ ).start()
- log_string(
- f"test_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}"
- )
+ for model in weakest_models:
+ model.TRAINING_LOCK.acquire()
+ model.TRAINING_LOCK.release()
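+ # The double acquire above is a minimal join: each lock is taken in the
+ # launch loop before its thread starts, and one_epoch releases it once
+ # training and testing are done, so the acquire here blocks until the
+ # corresponding model has finished; the release re-arms the lock for the
+ # next epoch. A self-contained sketch of the same pattern (illustrative
+ # names only):
+ #
+ #   lock = threading.Lock()
+ #   lock.acquire()
+ #   threading.Thread(target=lock.release, daemon=True).start()
+ #   lock.acquire()  # returns only once the worker thread has released it
+ #   lock.release()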
##################################################
-# Replace a fraction of the w_quizzes with fresh ones
+# Renew the w_quizzes of the models that were just trained
- quiz_machine.renew_w_quizzes(args.nb_train_samples // args.nb_gpts)
+ log_string(
+ f"cache_w_quizzes contains {quiz_machine.problem.nb_cached_quizzes()} quizzes"
+ )
+
+ # Renew entirely the train set
+
+ for model in weakest_models:
+ quiz_machine.renew_w_quizzes(model, args.nb_train_samples)
##################################################
# If all the models are good enough, generate new quizzes and
nb_for_test=nb_new_c_quizzes_for_test,
)
- for model in models:
- run_tests(model, quiz_machine, deterministic_synthesis=False)
-
-
######################################################################