X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=fc55b9ce7318c45a557488bbb31d81a8d6f07f3d;hb=f2ab5fd489adebe9b34ac825d39e41f13f287cb2;hp=0a266a8283db1f786de86c9984f819d909b1fa75;hpb=3ec49d885b2e0b71b0c44e0957709a20115ac828;p=culture.git

diff --git a/main.py b/main.py
index 0a266a8..fc55b9c 100755
--- a/main.py
+++ b/main.py
@@ -16,182 +16,14 @@ import ffutils
 import mygpt
 import sky, grids, quiz_machine
 
-import torch.multiprocessing as mp
+import threading
 
-# mp.set_start_method('spawn')
+import torch.multiprocessing as mp
 
 # world quizzes vs. culture quizzes
 
 ######################################################################
 
-
-def log_string(s):
-    t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
-
-    if log_file is not None:
-        log_file.write(t + s + "\n")
-        log_file.flush()
-
-    print(t + s)
-    sys.stdout.flush()
-
-
-######################################################################
-
-
-def run_tests(model, quiz_machine, deterministic_synthesis, local_device=None):
-    if local_device is None:
-        local_device = device
-
-    with torch.autograd.no_grad():
-        model.eval().to(local_device)
-
-        nb_test_samples, acc_test_loss = 0, 0.0
-        nb_samples_accumulated = 0
-
-        for input in quiz_machine.batches(model, split="test"):
-            input = input.to(local_device)
-
-            bs = model(mygpt.BracketedSequence(input))
-            output = bs.x
-
-            loss = F.cross_entropy(output.transpose(1, 2), input)
-
-            acc_test_loss += loss.item() * input.size(0)
-
-            nb_test_samples += input.size(0)
-
-        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
-
-        log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
-
-        model.main_test_accuracy = quiz_machine.produce_results(
-            n_epoch=n_epoch,
-            model=model,
-            result_dir=args.result_dir,
-            deterministic_synthesis=deterministic_synthesis,
-        )
-
-
-def one_epoch(model, quiz_machine, local_device=None):
-    if local_device is None:
-        local_device = device
-
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-
-    model.to(local_device).train()
-
-    nb_train_samples, acc_train_loss = 0, 0.0
-
-    for input in quiz_machine.batches(model, split="train"):
-        input = input.to(local_device)
-
-        if nb_train_samples % args.batch_size == 0:
-            optimizer.zero_grad()
-
-        output = model(mygpt.BracketedSequence(input)).x
-        loss = F.cross_entropy(output.transpose(1, 2), input)
-        acc_train_loss += loss.item() * input.size(0)
-
-        nb_train_samples += input.size(0)
-
-        loss.backward()
-
-        if nb_train_samples % args.batch_size == 0:
-            optimizer.step()
-
-    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-
-    log_string(f"train_perplexity {n_epoch} model {model.id} {train_perplexity}")
-
-    run_tests(model, quiz_machine, deterministic_synthesis=False)
-
-
-######################################################################
-
-
-def standard_validity(logproba):
-    l = logproba.sort(dim=-1).values
-    return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.99))
-    # warnings.warn("TEST!!!", RuntimeWarning)
-    # print(l.exp())
-    # return (l[:, 0] < math.log(0.99))
-
-
-def valid_c_quizzes(recorded, criteria):
-    result = [q[criteria(lp)] for q, lp in recorded]
-    return torch.cat(result, dim=0) if len(result) > 0 else torch.tensor([])
-
-
-######################################################################
-
-
-def create_c_quizzes(
-    models,
-    quiz_machine,
-    nb_for_train=1000,
-    nb_for_test=100,
-):
-    quizzes_and_logproba_records = []
-
-    nb_to_create = nb_for_train + nb_for_test
-
-    # ------------------------------------------------------------
-
-    file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat")
-
-    with open(file_name, "w") as logp_file:
-        while (
-            valid_c_quizzes(quizzes_and_logproba_records, standard_validity).size(0)
-            < nb_to_create
-        ):
-            # Select a model at random to generate the new quizzes
-
-            model_for_generation = models[torch.randint(len(models), (1,))]
-
-            c_quizzes = quiz_machine.generate_quizzes(
-                nb_to_create,
-                model_for_generation=model_for_generation,
-                temperature=args.generation_temperature,
-            )
-
-            c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
-
-            if c_quizzes.size(0) > 0:
-                logproba = quiz_machine.logproba_of_solutions(models, c_quizzes)
-                for l in logproba:
-                    s = " ".join([str(x.item()) for x in l])
-                    logp_file.write(s + "\n")
-                quizzes_and_logproba_records.append((c_quizzes, logproba))
-
-            nb_validated = valid_c_quizzes(
-                quizzes_and_logproba_records, standard_validity
-            ).size(0)
-
-            log_string(
-                f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create}"
-            )
-
-    # store the new c_quizzes which have been validated
-
-    new_c_quizzes = valid_c_quizzes(quizzes_and_logproba_records, standard_validity)
-
-    quiz_machine.reverse_random_half_in_place(new_c_quizzes)
-
-    quiz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
-    quiz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
-
-    # save a bunch of images to investigate what quizzes with a
-    # certain nb of correct predictions look like
-
-    q = new_c_quizzes[:72]
-
-    if q.size(0) > 0:
-        quiz_machine.save_quizzes(args.result_dir, f"culture_c_quiz_{n_epoch:04d}", q)
-
-
-######################################################################
-
 if torch.cuda.is_available():
     device = torch.device("cuda")
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -258,7 +90,7 @@ parser.add_argument("--min_to_validate", type=int, default=None)
 
 parser.add_argument("--max_to_validate", type=int, default=None)
 
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9)
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
 
 parser.add_argument("--generation_temperature", type=float, default=2.0)
 
@@ -266,6 +98,19 @@ parser.add_argument("--dirty_debug", action="store_true", default=False)
 
 ######################################################################
 
+grids_tasks = ", ".join(
+    [x.__name__.removeprefix("task_") for x in grids.Grids().all_tasks]
+)
+
+parser.add_argument(
+    "--grids_tasks",
+    type=str,
+    default=None,
+    help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
+)
+
+######################################################################
+
 parser.add_argument("--sky_height", type=int, default=6)
 
 parser.add_argument("--sky_width", type=int, default=8)
@@ -359,11 +204,6 @@ except FileExistsError:
 
 log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
 
-log_string(f"argv {' '.join(sys.argv)}")
-
-for n in vars(args):
-    log_string(f"args.{n} {getattr(args, n)}")
-
 if args.seed >= 0:
     # torch.backends.cudnn.deterministic = True
     # torch.backends.cudnn.benchmark = False
@@ -374,6 +214,26 @@ if args.seed >= 0:
 
 ######################################################################
 
+
+def log_string(s):
+    t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
+
+    if log_file is not None:
+        log_file.write(t + s + "\n")
+        log_file.flush()
+
+    print(t + s)
+    sys.stdout.flush()
+
+
+log_string(f"argv {' '.join(sys.argv)}")
+
+for n in vars(args):
+    log_string(f"args.{n} {getattr(args, n)}")
+
+
+######################################################################
+
 if args.dirty_debug:
     args.nb_train_samples = 2500
     args.nb_test_samples = 100
@@ -403,6 +263,7 @@ elif args.problem == "grids":
         max_nb_cached_chunks=args.nb_gpus * args.nb_train_samples // 100,
         chunk_size=100,
         nb_threads=args.nb_threads,
+        tasks=args.grids_tasks,
     )
     back_accuracy = True
 else:
@@ -431,6 +292,160 @@ log_string(f"vocabulary_size {vocabulary_size}")
 
 ######################################################################
 
+
+######################################################################
+
+
+def run_tests(model, quiz_machine, deterministic_synthesis, local_device=None):
+    if local_device is None:
+        local_device = device
+
+    with torch.autograd.no_grad():
+        model.eval().to(local_device)
+
+        nb_test_samples, acc_test_loss = 0, 0.0
+        nb_samples_accumulated = 0
+
+        for input in quiz_machine.batches(model, split="test"):
+            input = input.to(local_device)
+
+            bs = model(mygpt.BracketedSequence(input))
+            output = bs.x
+
+            loss = F.cross_entropy(output.transpose(1, 2), input)
+
+            acc_test_loss += loss.item() * input.size(0)
+
+            nb_test_samples += input.size(0)
+
+        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+
+        log_string(f"test_perplexity {n_epoch} {test_perplexity}")
+
+        model.main_test_accuracy = quiz_machine.produce_results(
+            n_epoch=n_epoch,
+            model=model,
+            result_dir=args.result_dir,
+            deterministic_synthesis=deterministic_synthesis,
+        )
+
+
+def one_epoch(model, quiz_machine, local_device=None):
+    if local_device is None:
+        local_device = device
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
+    model.to(local_device).train()
+
+    nb_train_samples, acc_train_loss = 0, 0.0
+
+    for input in quiz_machine.batches(model, split="train"):
+        input = input.to(local_device)
+
+        if nb_train_samples % args.batch_size == 0:
+            optimizer.zero_grad()
+
+        output = model(mygpt.BracketedSequence(input)).x
+        loss = F.cross_entropy(output.transpose(1, 2), input)
+        acc_train_loss += loss.item() * input.size(0)
+
+        nb_train_samples += input.size(0)
+
+        loss.backward()
+
+        if nb_train_samples % args.batch_size == 0:
+            optimizer.step()
+
+    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+    log_string(f"train_perplexity {n_epoch} model.id {model.id} {train_perplexity}")
+
+    run_tests(model, quiz_machine, deterministic_synthesis=False)
+
+
+######################################################################
+
+
+def standard_validity(logproba):
+    l = logproba.sort(dim=-1).values
+    return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.99))
+
+
+def valid_c_quizzes(recorded, criteria):
+    result = [q[criteria(lp)] for q, lp in recorded]
+    return torch.cat(result, dim=0) if len(result) > 0 else torch.tensor([])
+
+
+######################################################################
+
+
+def create_c_quizzes(
+    models,
+    quiz_machine,
+    nb_for_train=1000,
+    nb_for_test=100,
+):
+    quizzes_and_logproba_records = []
+
+    nb_to_create = nb_for_train + nb_for_test
+
+    # ------------------------------------------------------------
+
+    file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat")
+
+    with open(file_name, "w") as logp_file:
+        while (
+            valid_c_quizzes(quizzes_and_logproba_records, standard_validity).size(0)
+            < nb_to_create
+        ):
+            # Select a model at random to generate the new quizzes
+
+            model_for_generation = models[torch.randint(len(models), (1,))]
+
+            c_quizzes = quiz_machine.generate_quizzes(
+                nb_to_create,
+                model_for_generation=model_for_generation,
+                temperature=args.generation_temperature,
+            )
+
+            c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
+
+            if c_quizzes.size(0) > 0:
+                logproba = quiz_machine.logproba_of_solutions(models, c_quizzes)
+                for l in logproba:
+                    s = " ".join([str(x.item()) for x in l])
+                    logp_file.write(s + "\n")
+                quizzes_and_logproba_records.append((c_quizzes, logproba))
+
+            nb_validated = valid_c_quizzes(
+                quizzes_and_logproba_records, standard_validity
+            ).size(0)
+
+            log_string(
+                f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create}"
+            )
+
+    # store the new c_quizzes which have been validated
+
+    new_c_quizzes = valid_c_quizzes(quizzes_and_logproba_records, standard_validity)
+
+    quiz_machine.reverse_random_half_in_place(new_c_quizzes)
+
+    quiz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
+    quiz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
+
+    # save a bunch of images to investigate what quizzes with a
+    # certain nb of correct predictions look like
+
+    q = new_c_quizzes[:72]
+
+    if q.size(0) > 0:
+        quiz_machine.save_quizzes(args.result_dir, f"culture_c_quiz_{n_epoch:04d}", q)
+
+
+######################################################################
+
 models = []
 
 for k in range(args.nb_gpts):
@@ -449,13 +464,9 @@ for k in range(args.nb_gpts):
     model.main_test_accuracy = 0.0
     model.id = k
 
-    model.train_w_quizzes = quiz_machine.generate_token_sequences(
-        args.nb_train_samples
-    ).to(device)
+    model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples)
     quiz_machine.reverse_random_half_in_place(model.train_w_quizzes)
-    model.test_w_quizzes = quiz_machine.generate_token_sequences(
-        args.nb_test_samples
-    ).to(device)
+    model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples)
     quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
 
     models.append(model)
@@ -531,7 +542,7 @@ if args.dirty_debug:
 
     def standard_validity(logproba):
         l = logproba.sort(dim=-1).values
-        return l[:, 0] < math.log(0.99)
+        return l[:, 0] < math.log(0.5)
 
 
 ######################################################################
@@ -543,38 +554,37 @@ for n_epoch in range(args.nb_epochs):
 
     log_string(f"current_test_accuracies {cta}")
 
     ##################################################
-    # Select, improve, and eval the worst models
+    # Select, improve, and eval the worst model
 
     ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
 
     weakest_models = ranked_models[: args.nb_gpus]
 
-    processes = []
+    threads = []
 
     for gpu_id, model in enumerate(weakest_models):
-        log_string(
-            f"training model {model.id} main_test_accuracy {model.main_test_accuracy}"
-        )
+        log_string(f"training model {model.id}")
 
-        process = mp.Process(
-            target=one_epoch, args=(model, quiz_machine, f"cuda:{gpu_id}")
+        t = threading.Thread(
+            target=one_epoch, daemon=True, args=(model, quiz_machine, f"cuda:{gpu_id}")
         )
 
-        processes.append(process)
+        threads.append(t)
 
-    for process in processes:
-        process.start()
+        t.start()
 
-    for process in processes:
-        process.join()
+    for t in threads:
+        t.join()
 
     ##################################################
-    # Renew the train sets
+    # Replace a fraction of the w_quizzes with fresh ones
 
     log_string(
        f"cache_w_quizzes contains {quiz_machine.problem.nb_cached_quizzes()} quizzes"
    )
 
+    # Renew entirely the train set
+
     for model in weakest_models:
        quiz_machine.renew_w_quizzes(model, args.nb_train_samples)
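
The last hunk above swaps `mp.Process` for `threading.Thread` when training the weakest models, one per GPU. A minimal sketch of that dispatch pattern, assuming a PyTorch environment; `train_one_epoch` and `train_weakest` are illustrative stand-ins for the patch's `one_epoch()` and the main-loop code, not names from the repository:

    import threading


    def train_one_epoch(model, quiz_machine, device):
        # Stand-in for one_epoch() above: pin the model to its own GPU
        # and run one epoch of optimization there.
        model.to(device).train()
        # ... forward/backward/step over the epoch's batches ...


    def train_weakest(models, quiz_machine, nb_gpus):
        # Rank models by test accuracy and give each of the weakest its own GPU.
        weakest = sorted(models, key=lambda m: float(m.main_test_accuracy))[:nb_gpus]

        threads = []

        for gpu_id, model in enumerate(weakest):
            t = threading.Thread(
                target=train_one_epoch,
                daemon=True,
                args=(model, quiz_machine, f"cuda:{gpu_id}"),
            )
            threads.append(t)
            t.start()

        for t in threads:
            t.join()

Plain threads are plausible here because the heavy lifting runs inside CUDA kernels, which do not hold Python's GIL, and they sidestep the start-method and pickling constraints of `torch.multiprocessing` (note the commented-out `mp.set_start_method('spawn')` that this commit deletes).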
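
`standard_validity` sorts, for each candidate c_quiz, the models' log-probabilities of producing the correct solution, and keeps a quiz when the weakest model likely fails it (success probability below 0.5) while every other model solves it confidently (above 0.99, enforced on the second-smallest entry): hard for at least one model, easy for the rest. A toy check of the rule, with invented values:

    import math

    import torch

    # Rows: candidate quizzes; columns: each model's log-probability of
    # solving the quiz correctly.
    logproba = torch.tensor(
        [
            [math.log(0.999), math.log(0.995), math.log(0.30)],  # one model fails
            [math.log(0.999), math.log(0.999), math.log(0.999)],  # too easy
            [math.log(0.40), math.log(0.35), math.log(0.999)],  # too hard
        ]
    )

    l = logproba.sort(dim=-1).values
    keep = (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.99))
    print(keep)  # tensor([ True, False, False])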
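
The new `--grids_tasks` argument restricts the grids problem to a comma-separated subset of tasks, whose names are the `task_*` methods of `grids.Grids` with the prefix stripped. The same expression the patch uses to build the help string can list the accepted values (assumes `grids.py` from this repository is importable, and Python >= 3.9 for `str.removeprefix`):

    import grids

    print(", ".join(x.__name__.removeprefix("task_") for x in grids.Grids().all_tasks))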