X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=b88cbc4b545c553e36629d43b174cfc056250bd9;hb=7c79c0b140c88a529962945ec5b482fe90c55581;hp=a7338c7bdf5402cfa6daded72f594c8a5f88c62d;hpb=a86dff174205c38d8e90d0d89ea399a6afb36359;p=culture.git diff --git a/main.py b/main.py index a7338c7..b88cbc4 100755 --- a/main.py +++ b/main.py @@ -255,6 +255,8 @@ elif args.problem == "grids": else: raise ValueError +problem.save_some_examples(args.result_dir) + quiz_machine = quiz_machine.QuizMachine( problem=problem, nb_train_samples=args.nb_train_samples, @@ -347,8 +349,6 @@ def one_epoch(model, quiz_machine, local_device=None): run_tests(model, quiz_machine, deterministic_synthesis=False) - model.TRAINING_LOCK.release() - ###################################################################### @@ -449,7 +449,6 @@ for k in range(args.nb_gpts): model.main_test_accuracy = 0.0 model.id = k - model.TRAINING_LOCK = threading.Lock() model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples) quiz_machine.reverse_random_half_in_place(model.train_w_quizzes) @@ -547,20 +546,21 @@ for n_epoch in range(args.nb_epochs): weakest_models = ranked_models[: args.nb_gpus] + threads = [] + for gpu_id, model in enumerate(weakest_models): - model.TRAINING_LOCK.acquire() + log_string(f"training model {model.id}") - log_string( - f"training model {model.id} main_test_accuracy {model.main_test_accuracy}" + t = threading.Thread( + target=one_epoch, daemon=True, args=(model, quiz_machine, f"cuda:{gpu_id}") ) - threading.Thread( - target=one_epoch, daemon=True, args=(model, quiz_machine, f"cuda:{gpu_id}") - ).start() + threads.append(t) - for model in weakest_models: - model.TRAINING_LOCK.acquire() - model.TRAINING_LOCK.release() + t.start() + + for t in threads: + t.join() ################################################## # Replace a fraction of the w_quizzes with fresh ones