import threading
+import torch.multiprocessing as mp
+
# world quizzes vs. culture quizzes
######################################################################
parser.add_argument("--max_to_validate", type=int, default=None)
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9)
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
parser.add_argument("--generation_temperature", type=float, default=2.0)
train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
- log_string(f"train_perplexity {n_epoch} {train_perplexity}")
+ log_string(f"train_perplexity {n_epoch} model.id {model.id} {train_perplexity}")
run_tests(model, quiz_machine, deterministic_synthesis=False)
- model.TRAINING_LOCK.release()
-
######################################################################
def standard_validity(logproba):
l = logproba.sort(dim=-1).values
return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.99))
- # warnings.warn("TEST!!!", RuntimeWarning)
- # print(l.exp())
- # return (l[:, 0] < math.log(0.99))
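# Quick sanity check of the predicate above on toy values (illustrative only,
# assuming torch and math are imported at the top of the script): a c-quiz is
# valid when the weakest model solves it with probability below 0.5 while the
# second-weakest still solves it with probability above 0.99.
_toy = torch.tensor([[0.3, 0.995], [0.6, 0.999]]).log()
assert standard_validity(_toy).tolist() == [True, False]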
def valid_c_quizzes(recorded, criteria):
model.main_test_accuracy = 0.0
model.id = k
- model.TRAINING_LOCK = threading.Lock()
- model.train_w_quizzes = quiz_machine.generate_token_sequences(
- args.nb_train_samples
- ).to(device)
+ model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples)
quiz_machine.reverse_random_half_in_place(model.train_w_quizzes)
- model.test_w_quizzes = quiz_machine.generate_token_sequences(
- args.nb_test_samples
- ).to(device)
+ model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples)
quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
models.append(model)
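# Note on the two hunks above: freshly generated w_quizzes now stay on the
# CPU instead of being moved to a single global device. A plausible reason
# (assumed, not stated in the diff) is that with one training thread per GPU,
# each model has to move its own batches to its assigned device, e.g.:
#
#     for input in model.train_w_quizzes.split(args.batch_size):
#         input = input.to(f"cuda:{gpu_id}")  # hypothetical per-thread device
#         ...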
def standard_validity(logproba):
l = logproba.sort(dim=-1).values
- return l[:, 0] < math.log(0.99)
+ return l[:, 0] < math.log(0.5)
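# The change above tightens the criterion: a c-quiz is now kept only if the
# weakest model solves it with probability below 0.5 (i.e. it genuinely
# fails), whereas the previous log(0.99) cutoff also kept quizzes that the
# weakest model solved with probability anywhere in [0.5, 0.99).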
######################################################################
log_string(f"current_test_accuracies {cta}")
##################################################
- # Select, improve, and eval the worst models
+ # Select, improve, and eval the nb_gpus weakest models
ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
weakest_models = ranked_models[: args.nb_gpus]
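# sorted() is ascending, so ranked_models[: args.nb_gpus] picks the
# args.nb_gpus models with the lowest test accuracy, one per available GPU.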
+ threads = []
+
for gpu_id, model in enumerate(weakest_models):
-        model.TRAINING_LOCK.acquire()
-        log_string(
-            f"training model {model.id} main_test_accuracy {model.main_test_accuracy}"
+        log_string(f"training model {model.id}")
+        t = threading.Thread(
+            target=one_epoch, daemon=True, args=(model, quiz_machine, f"cuda:{gpu_id}")
)
-    threading.Thread(
-        target=one_epoch, daemon=True, args=(model, quiz_machine, f"cuda:{gpu_id}")
-    ).start()
+        threads.append(t)
+        t.start()
-    for model in weakest_models:
-        model.TRAINING_LOCK.acquire()
-        model.TRAINING_LOCK.release()
+
+ for t in threads:
+ t.join()
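# The join-based rendezvous above replaces the old lock dance, in which the
# main thread acquired each model's TRAINING_LOCK (blocking until one_epoch
# released it) just to learn that training had finished. Thread.join expresses
# "wait until done" directly and no longer requires one_epoch to release a
# lock on exit; note that with daemon=True the threads would be killed
# abruptly if the main thread ever exited without joining them.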
##################################################
- # Renew the train sets
+ # Replace a fraction of the w_quizzes with fresh ones
log_string(
f"cache_w_quizzes contains {quiz_machine.problem.nb_cached_quizzes()} quizzes"
)
+ # Renew the train set entirely
+
for model in weakest_models:
quiz_machine.renew_w_quizzes(model, args.nb_train_samples)
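# A minimal sketch, assuming renew_w_quizzes(model, n) overwrites n of the
# model's world quizzes with freshly generated ones (signature taken from the
# call above, behavior inferred; with n = args.nb_train_samples the whole
# train set is renewed, matching the comment):
#
#     def renew_w_quizzes(self, model, nb):
#         model.train_w_quizzes[:nb] = self.generate_token_sequences(nb)
#         self.reverse_random_half_in_place(model.train_w_quizzes[:nb])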