X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=main.py;h=b88cbc4b545c553e36629d43b174cfc056250bd9;hb=7c79c0b140c88a529962945ec5b482fe90c55581;hp=5956be5f91effe796d5b984e0447ddc7c4c46e67;hpb=07c065e77f1d2a775814ec402752a4a8eb6c7574;p=culture.git

diff --git a/main.py b/main.py
index 5956be5..b88cbc4 100755
--- a/main.py
+++ b/main.py
@@ -18,6 +18,8 @@ import sky, grids, quiz_machine
 
 import threading
 
+import torch.multiprocessing as mp
+
 # world quizzes vs. culture quizzes
 
 ######################################################################
@@ -88,7 +90,7 @@ parser.add_argument("--min_to_validate", type=int, default=None)
 
 parser.add_argument("--max_to_validate", type=int, default=None)
 
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9)
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
 
 parser.add_argument("--generation_temperature", type=float, default=2.0)
 
@@ -343,12 +345,10 @@ def one_epoch(model, quiz_machine, local_device=None):
 
     train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
 
-    log_string(f"train_perplexity {n_epoch} {train_perplexity}")
+    log_string(f"train_perplexity {n_epoch} model.id {model.id} {train_perplexity}")
 
     run_tests(model, quiz_machine, deterministic_synthesis=False)
 
-    model.TRAINING_LOCK.release()
-
 
 
 ######################################################################
@@ -356,9 +356,6 @@ def one_epoch(model, quiz_machine, local_device=None):
 def standard_validity(logproba):
     l = logproba.sort(dim=-1).values
     return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.99))
-    # warnings.warn("TEST!!!", RuntimeWarning)
-    # print(l.exp())
-    # return (l[:, 0] < math.log(0.99))
 
 
 def valid_c_quizzes(recorded, criteria):
@@ -452,15 +449,10 @@ for k in range(args.nb_gpts):
     model.main_test_accuracy = 0.0
     model.id = k
-    model.TRAINING_LOCK = threading.Lock()
 
-    model.train_w_quizzes = quiz_machine.generate_token_sequences(
-        args.nb_train_samples
-    ).to(device)
+    model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples)
     quiz_machine.reverse_random_half_in_place(model.train_w_quizzes)
 
-    model.test_w_quizzes = quiz_machine.generate_token_sequences(
-        args.nb_test_samples
-    ).to(device)
+    model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples)
     quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
 
     models.append(model)
@@ -536,7 +528,7 @@ if args.dirty_debug:
 
     def standard_validity(logproba):
         l = logproba.sort(dim=-1).values
-        return l[:, 0] < math.log(0.99)
+        return l[:, 0] < math.log(0.5)
 
 
 ######################################################################
@@ -548,34 +540,37 @@ for n_epoch in range(args.nb_epochs):
 
     log_string(f"current_test_accuracies {cta}")
 
     ##################################################
-    # Select, improve, and eval the worst models
+    # Select, improve, and eval the worst model
 
     ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
 
     weakest_models = ranked_models[: args.nb_gpus]
 
+    threads = []
+
     for gpu_id, model in enumerate(weakest_models):
-        model.TRAINING_LOCK.acquire()
+        log_string(f"training model {model.id}")
 
-        log_string(
-            f"training model {model.id} main_test_accuracy {model.main_test_accuracy}"
+        t = threading.Thread(
+            target=one_epoch, daemon=True, args=(model, quiz_machine, f"cuda:{gpu_id}")
         )
 
-        threading.Thread(
-            target=one_epoch, daemon=True, args=(model, quiz_machine, f"cuda:{gpu_id}")
-        ).start()
+        threads.append(t)
 
-    for model in weakest_models:
-        model.TRAINING_LOCK.acquire()
-        model.TRAINING_LOCK.release()
+        t.start()
+
+    for t in threads:
+        t.join()
 
     ##################################################
-    # Renew the train sets
+    # Replace a fraction of the w_quizzes with fresh ones
 
     log_string(
         f"cache_w_quizzes contains {quiz_machine.problem.nb_cached_quizzes()} quizzes"
     )
 
+    # Renew entirely the train set
+
     for model in weakest_models:
         quiz_machine.renew_w_quizzes(model, args.nb_train_samples)
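
The last hunk replaces the TRAINING_LOCK acquire/release handshake with an explicit create/start/join cycle, one thread per GPU. Below is a minimal, self-contained sketch of that dispatch pattern, not code from the repository: DummyModel and fake_one_epoch are hypothetical stand-ins for the repository's GPT models and one_epoch(model, quiz_machine, local_device), and only the threading and torch modules are assumed.

import threading

import torch


class DummyModel:
    # Hypothetical stand-in for the GPT models trained in main.py.
    def __init__(self, model_id):
        self.id = model_id
        self.main_test_accuracy = 0.0


def fake_one_epoch(model, local_device):
    # Hypothetical stand-in for one_epoch(model, quiz_machine, local_device):
    # it only touches the target device to mark where training would happen.
    x = torch.randn(4, 4, device=local_device)
    model.main_test_accuracy = float(x.mean())


if __name__ == "__main__":
    nb_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1

    weakest_models = [DummyModel(k) for k in range(nb_gpus)]

    threads = []

    for gpu_id, model in enumerate(weakest_models):
        device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"

        # One daemon thread per model, pinned to its own device.
        t = threading.Thread(
            target=fake_one_epoch, daemon=True, args=(model, device)
        )
        threads.append(t)
        t.start()

    # Joining the stored threads is the synchronisation point: nothing
    # downstream runs until every per-GPU epoch has finished.
    for t in threads:
        t.join()

    for model in weakest_models:
        print(f"model {model.id} main_test_accuracy {model.main_test_accuracy}")

Keeping the Thread objects and joining them makes the wait explicit, whereas the removed code emulated the same barrier by re-acquiring and releasing a per-model lock held for the whole epoch.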