From 3ec49d885b2e0b71b0c44e0957709a20115ac828 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 11 Jul 2024 15:31:56 +0200 Subject: [PATCH] Update. --- main.py | 377 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 188 insertions(+), 189 deletions(-) diff --git a/main.py b/main.py index 5956be5..0a266a8 100755 --- a/main.py +++ b/main.py @@ -16,12 +16,182 @@ import ffutils import mygpt import sky, grids, quiz_machine -import threading +import torch.multiprocessing as mp + +# mp.set_start_method('spawn') # world quizzes vs. culture quizzes ###################################################################### + +def log_string(s): + t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) + + if log_file is not None: + log_file.write(t + s + "\n") + log_file.flush() + + print(t + s) + sys.stdout.flush() + + +###################################################################### + + +def run_tests(model, quiz_machine, deterministic_synthesis, local_device=None): + if local_device is None: + local_device = device + + with torch.autograd.no_grad(): + model.eval().to(local_device) + + nb_test_samples, acc_test_loss = 0, 0.0 + nb_samples_accumulated = 0 + + for input in quiz_machine.batches(model, split="test"): + input = input.to(local_device) + + bs = model(mygpt.BracketedSequence(input)) + output = bs.x + + loss = F.cross_entropy(output.transpose(1, 2), input) + + acc_test_loss += loss.item() * input.size(0) + + nb_test_samples += input.size(0) + + test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) + + log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}") + + model.main_test_accuracy = quiz_machine.produce_results( + n_epoch=n_epoch, + model=model, + result_dir=args.result_dir, + deterministic_synthesis=deterministic_synthesis, + ) + + +def one_epoch(model, quiz_machine, local_device=None): + if local_device is None: + local_device = device + + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + + model.to(local_device).train() + + nb_train_samples, acc_train_loss = 0, 0.0 + + for input in quiz_machine.batches(model, split="train"): + input = input.to(local_device) + + if nb_train_samples % args.batch_size == 0: + optimizer.zero_grad() + + output = model(mygpt.BracketedSequence(input)).x + loss = F.cross_entropy(output.transpose(1, 2), input) + acc_train_loss += loss.item() * input.size(0) + + nb_train_samples += input.size(0) + + loss.backward() + + if nb_train_samples % args.batch_size == 0: + optimizer.step() + + train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) + + log_string(f"train_perplexity {n_epoch} model {model.id} {train_perplexity}") + + run_tests(model, quiz_machine, deterministic_synthesis=False) + + +###################################################################### + + +def standard_validity(logproba): + l = logproba.sort(dim=-1).values + return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.99)) + # warnings.warn("TEST!!!", RuntimeWarning) + # print(l.exp()) + # return (l[:, 0] < math.log(0.99)) + + +def valid_c_quizzes(recorded, criteria): + result = [q[criteria(lp)] for q, lp in recorded] + return torch.cat(result, dim=0) if len(result) > 0 else torch.tensor([]) + + +###################################################################### + + +def create_c_quizzes( + models, + quiz_machine, + nb_for_train=1000, + nb_for_test=100, +): + quizzes_and_logproba_records = [] + + nb_to_create = nb_for_train + nb_for_test + + # 
------------------------------------------------------------ + + file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat") + + with open(file_name, "w") as logp_file: + while ( + valid_c_quizzes(quizzes_and_logproba_records, standard_validity).size(0) + < nb_to_create + ): + # Select a model at random to generate the new quizzes + + model_for_generation = models[torch.randint(len(models), (1,))] + + c_quizzes = quiz_machine.generate_quizzes( + nb_to_create, + model_for_generation=model_for_generation, + temperature=args.generation_temperature, + ) + + c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)] + + if c_quizzes.size(0) > 0: + logproba = quiz_machine.logproba_of_solutions(models, c_quizzes) + for l in logproba: + s = " ".join([str(x.item()) for x in l]) + logp_file.write(s + "\n") + quizzes_and_logproba_records.append((c_quizzes, logproba)) + + nb_validated = valid_c_quizzes( + quizzes_and_logproba_records, standard_validity + ).size(0) + + log_string( + f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create}" + ) + + # store the new c_quizzes which have been validated + + new_c_quizzes = valid_c_quizzes(quizzes_and_logproba_records, standard_validity) + + quiz_machine.reverse_random_half_in_place(new_c_quizzes) + + quiz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True) + quiz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False) + + # save a bunch of images to investigate what quizzes with a + # certain nb of correct predictions look like + + q = new_c_quizzes[:72] + + if q.size(0) > 0: + quiz_machine.save_quizzes(args.result_dir, f"culture_c_quiz_{n_epoch:04d}", q) + + +###################################################################### + if torch.cuda.is_available(): device = torch.device("cuda") torch.backends.cuda.matmul.allow_tf32 = True @@ -189,6 +359,11 @@ except FileExistsError: log_file = open(os.path.join(args.result_dir, args.log_filename), "a") +log_string(f"argv {' '.join(sys.argv)}") + +for n in vars(args): + log_string(f"args.{n} {getattr(args, n)}") + if args.seed >= 0: # torch.backends.cudnn.deterministic = True # torch.backends.cudnn.benchmark = False @@ -199,26 +374,6 @@ if args.seed >= 0: ###################################################################### - -def log_string(s): - t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) - - if log_file is not None: - log_file.write(t + s + "\n") - log_file.flush() - - print(t + s) - sys.stdout.flush() - - -log_string(f"argv {' '.join(sys.argv)}") - -for n in vars(args): - log_string(f"args.{n} {getattr(args, n)}") - - -###################################################################### - if args.dirty_debug: args.nb_train_samples = 2500 args.nb_test_samples = 100 @@ -276,165 +431,6 @@ log_string(f"vocabulary_size {vocabulary_size}") ###################################################################### - -###################################################################### - - -def run_tests(model, quiz_machine, deterministic_synthesis, local_device=None): - if local_device is None: - local_device = device - - with torch.autograd.no_grad(): - model.eval().to(local_device) - - nb_test_samples, acc_test_loss = 0, 0.0 - nb_samples_accumulated = 0 - - for input in quiz_machine.batches(model, split="test"): - input = input.to(local_device) - - bs = model(mygpt.BracketedSequence(input)) - output = bs.x - - loss = F.cross_entropy(output.transpose(1, 2), input) - - acc_test_loss += loss.item() * 
input.size(0) - - nb_test_samples += input.size(0) - - test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) - - log_string(f"test_perplexity {n_epoch} {test_perplexity}") - - model.main_test_accuracy = quiz_machine.produce_results( - n_epoch=n_epoch, - model=model, - result_dir=args.result_dir, - deterministic_synthesis=deterministic_synthesis, - ) - - -def one_epoch(model, quiz_machine, local_device=None): - if local_device is None: - local_device = device - - optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) - - model.to(local_device).train() - - nb_train_samples, acc_train_loss = 0, 0.0 - - for input in quiz_machine.batches(model, split="train"): - input = input.to(local_device) - - if nb_train_samples % args.batch_size == 0: - optimizer.zero_grad() - - output = model(mygpt.BracketedSequence(input)).x - loss = F.cross_entropy(output.transpose(1, 2), input) - acc_train_loss += loss.item() * input.size(0) - - nb_train_samples += input.size(0) - - loss.backward() - - if nb_train_samples % args.batch_size == 0: - optimizer.step() - - train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) - - log_string(f"train_perplexity {n_epoch} {train_perplexity}") - - run_tests(model, quiz_machine, deterministic_synthesis=False) - - model.TRAINING_LOCK.release() - - -###################################################################### - - -def standard_validity(logproba): - l = logproba.sort(dim=-1).values - return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.99)) - # warnings.warn("TEST!!!", RuntimeWarning) - # print(l.exp()) - # return (l[:, 0] < math.log(0.99)) - - -def valid_c_quizzes(recorded, criteria): - result = [q[criteria(lp)] for q, lp in recorded] - return torch.cat(result, dim=0) if len(result) > 0 else torch.tensor([]) - - -###################################################################### - - -def create_c_quizzes( - models, - quiz_machine, - nb_for_train=1000, - nb_for_test=100, -): - quizzes_and_logproba_records = [] - - nb_to_create = nb_for_train + nb_for_test - - # ------------------------------------------------------------ - - file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat") - - with open(file_name, "w") as logp_file: - while ( - valid_c_quizzes(quizzes_and_logproba_records, standard_validity).size(0) - < nb_to_create - ): - # Select a model at random to generate the new quizzes - - model_for_generation = models[torch.randint(len(models), (1,))] - - c_quizzes = quiz_machine.generate_quizzes( - nb_to_create, - model_for_generation=model_for_generation, - temperature=args.generation_temperature, - ) - - c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)] - - if c_quizzes.size(0) > 0: - logproba = quiz_machine.logproba_of_solutions(models, c_quizzes) - for l in logproba: - s = " ".join([str(x.item()) for x in l]) - logp_file.write(s + "\n") - quizzes_and_logproba_records.append((c_quizzes, logproba)) - - nb_validated = valid_c_quizzes( - quizzes_and_logproba_records, standard_validity - ).size(0) - - log_string( - f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create}" - ) - - # store the new c_quizzes which have been validated - - new_c_quizzes = valid_c_quizzes(quizzes_and_logproba_records, standard_validity) - - quiz_machine.reverse_random_half_in_place(new_c_quizzes) - - quiz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True) - quiz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False) - - # save 
a bunch of images to investigate what quizzes with a - # certain nb of correct predictions look like - - q = new_c_quizzes[:72] - - if q.size(0) > 0: - quiz_machine.save_quizzes(args.result_dir, f"culture_c_quiz_{n_epoch:04d}", q) - - -###################################################################### - models = [] for k in range(args.nb_gpts): @@ -452,7 +448,6 @@ for k in range(args.nb_gpts): model.main_test_accuracy = 0.0 model.id = k - model.TRAINING_LOCK = threading.Lock() model.train_w_quizzes = quiz_machine.generate_token_sequences( args.nb_train_samples @@ -554,20 +549,24 @@ for n_epoch in range(args.nb_epochs): weakest_models = ranked_models[: args.nb_gpus] - for gpu_id, model in enumerate(weakest_models): - model.TRAINING_LOCK.acquire() + processes = [] + for gpu_id, model in enumerate(weakest_models): log_string( f"training model {model.id} main_test_accuracy {model.main_test_accuracy}" ) - threading.Thread( - target=one_epoch, daemon=True, args=(model, quiz_machine, f"cuda:{gpu_id}") - ).start() + process = mp.Process( + target=one_epoch, args=(model, quiz_machine, f"cuda:{gpu_id}") + ) - for model in weakest_models: - model.TRAINING_LOCK.acquire() - model.TRAINING_LOCK.release() + processes.append(process) + + for process in processes: + process.start() + + for process in processes: + process.join() ################################################## # Renew the train sets -- 2.39.5
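
The core change in this patch replaces the per-model training threads (previously serialized with model.TRAINING_LOCK) by one torch.multiprocessing process per GPU, started and joined explicitly in the epoch loop. Below is a minimal, self-contained sketch of that process-per-GPU pattern; it is not the repository's code, and the toy model, random data, and train_one_epoch body are placeholders used only to illustrate the structure and the 'spawn' start method that the patch hints at with the commented-out mp.set_start_method('spawn') (CUDA tensors cannot be used in fork()ed children).

    # Sketch only: process-per-GPU training with torch.multiprocessing,
    # assuming a trivial placeholder model and random data.
    import torch
    import torch.nn as nn
    import torch.multiprocessing as mp


    def train_one_epoch(model, device):
        # Placeholder training step: one forward/backward/step on random data.
        model.to(device).train()
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        x = torch.randn(8, 16, device=device)
        y = torch.randn(8, 1, device=device)
        loss = nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


    if __name__ == "__main__":
        # 'spawn' is required when child processes touch the GPU.
        mp.set_start_method("spawn")

        nb_gpus = max(1, torch.cuda.device_count())
        models = [nn.Linear(16, 1) for _ in range(nb_gpus)]

        processes = []
        for gpu_id, model in enumerate(models):
            device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"
            p = mp.Process(target=train_one_epoch, args=(model, device))
            processes.append(p)

        for p in processes:
            p.start()
        for p in processes:
            p.join()

One caveat with 'spawn': each child process works on a pickled copy of the model, so updated weights do not propagate back to the parent automatically; they would have to be communicated back explicitly (for example through a queue, shared memory, or checkpoint files), a step this sketch omits.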