import mygpt
import sky, grids, quiz_machine
+import threading
+
# world quizzes vs. culture quizzes
######################################################################
parser.add_argument("--seed", type=int, default=0)
-parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
+parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1)
########################################
parser.add_argument("--nb_threads", type=int, default=1)
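+# One training thread per GPU: each generation, the nb_gpus weakest
+# models are trained concurrently, one per device (see the main loop).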
+parser.add_argument("--nb_gpus", type=int, default=1)
+
parser.add_argument("--nb_gpts", type=int, default=5)
parser.add_argument("--min_to_validate", type=int, default=None)
nb_birds=args.sky_nb_birds,
nb_iterations=args.sky_nb_iterations,
speed=args.sky_speed,
- max_nb_cached_chunks=args.nb_train_samples // 100,
+ max_nb_cached_chunks=args.nb_gpus * args.nb_train_samples // 100,
chunk_size=100,
nb_threads=args.nb_threads,
)
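+    # The cache is sized as nb_gpus training sets' worth of 100-sample
+    # chunks, since one training thread per GPU now consumes batches
+    # concurrently.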
back_accuracy = False
elif args.problem == "grids":
problem = grids.Grids(
- max_nb_cached_chunks=args.nb_train_samples // 100,
+ max_nb_cached_chunks=args.nb_gpus * args.nb_train_samples // 100,
chunk_size=100,
nb_threads=args.nb_threads,
)
log_string(f"vocabulary_size {vocabulary_size}")
######################################################################
-##############################
-
-
-def one_epoch(model, quiz_machine):
- optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-
- model.train()
-
- nb_train_samples, acc_train_loss = 0, 0.0
-
- for input in quiz_machine.batches(model, split="train"):
- input = input.to(device)
-
- if nb_train_samples % args.batch_size == 0:
- optimizer.zero_grad()
-
- output = model(mygpt.BracketedSequence(input)).x
- loss = F.cross_entropy(output.transpose(1, 2), input)
- acc_train_loss += loss.item() * input.size(0)
-
- nb_train_samples += input.size(0)
-
- loss.backward()
-
- if nb_train_samples % args.batch_size == 0:
- optimizer.step()
-
- train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-
- log_string(f"train_perplexity {n_epoch} {train_perplexity}")
######################################################################
-def run_tests(model, quiz_machine, deterministic_synthesis):
+def run_tests(model, quiz_machine, deterministic_synthesis, local_device=None):
+ if local_device is None:
+ local_device = device
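+    # Worker threads pass their own cuda:<n> device; calls from the
+    # main thread fall back to the global one.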
+
with torch.autograd.no_grad():
- model.eval()
+ model.eval().to(local_device)
nb_test_samples, acc_test_loss = 0, 0.0
nb_samples_accumulated = 0
for input in quiz_machine.batches(model, split="test"):
- input = input.to(device)
+ input = input.to(local_device)
bs = model(mygpt.BracketedSequence(input))
output = bs.x
)
+def one_epoch(model, quiz_machine, local_device=None):
+ if local_device is None:
+ local_device = device
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
+ model.to(local_device).train()
+
+ nb_train_samples, acc_train_loss = 0, 0.0
+
+ for input in quiz_machine.batches(model, split="train"):
+ input = input.to(local_device)
+
+ if nb_train_samples % args.batch_size == 0:
+ optimizer.zero_grad()
+
+ output = model(mygpt.BracketedSequence(input)).x
+ loss = F.cross_entropy(output.transpose(1, 2), input)
+ acc_train_loss += loss.item() * input.size(0)
+
+ nb_train_samples += input.size(0)
+
+ loss.backward()
+
+ if nb_train_samples % args.batch_size == 0:
+ optimizer.step()
+
+ train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+ log_string(f"train_perplexity {n_epoch} {train_perplexity}")
+
+    run_tests(
+        model, quiz_machine, deterministic_synthesis=False, local_device=local_device
+    )
+
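+    # Signal the main loop that this model's epoch (training + test
+    # pass) is done; the lock was acquired before this thread was
+    # spawned.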
+ model.TRAINING_LOCK.release()
+
+
######################################################################
model.main_test_accuracy = 0.0
model.id = k
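+    # Per-model lock, held while a worker thread trains this model so
+    # the main loop can wait for the epoch to finish (see one_epoch).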
+ model.TRAINING_LOCK = threading.Lock()
model.train_w_quizzes = quiz_machine.generate_token_sequences(
args.nb_train_samples
##################################################
-    # Select, improve, and eval the worst model
+    # Select, improve, and eval the nb_gpus weakest models
- weakest_model = min(models, key=lambda m: float(m.main_test_accuracy))
+ ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
- log_string(
- f"training model {weakest_model.id} main_test_accuracy {weakest_model.main_test_accuracy}"
- )
+ weakest_models = ranked_models[: args.nb_gpus]
- one_epoch(weakest_model, quiz_machine)
+ for gpu_id, model in enumerate(weakest_models):
+ model.TRAINING_LOCK.acquire()
- log_string(
- f"train_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}"
- )
+ log_string(
+ f"training model {model.id} main_test_accuracy {model.main_test_accuracy}"
+ )
- run_tests(weakest_model, quiz_machine, deterministic_synthesis=False)
+ threading.Thread(
+ target=one_epoch, daemon=True, args=(model, quiz_machine, f"cuda:{gpu_id}")
+ ).start()
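+    # A plain "cuda:<gpu_id>" string is a valid device argument for
+    # the .to() calls in one_epoch and run_tests.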
- log_string(
- f"test_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}"
- )
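+    # Wait for every training thread: acquire() blocks until the
+    # thread releases the lock at the end of one_epoch, and the
+    # immediate release() leaves the lock free for the next generation.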
+ for model in weakest_models:
+ model.TRAINING_LOCK.acquire()
+ model.TRAINING_LOCK.release()
##################################################
-    # Replace a fraction of the w_quizzes with fresh ones
-    # Renew entirely the train set
+    # Renew the entire set of w_quizzes for the models just trained
- quiz_machine.renew_w_quizzes(model, args.nb_train_samples)
+ for model in weakest_models:
+ quiz_machine.renew_w_quizzes(model, args.nb_train_samples)
##################################################
# If all the models are good enough, generate new quizzes and
nb_for_test=nb_new_c_quizzes_for_test,
)
- for model in models:
- run_tests(model, quiz_machine, deterministic_synthesis=False)
-
-
######################################################################
import mygpt
from mygpt import BracketedSequence
+import threading
+
######################################################################
# ar_mask is a tensor with 0s and 1s, of same shape as input, with
self.prompt_len = None
self.answer_len = None
- # self.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
- # self.reverse_random_half_in_place(self.train_w_quizzes)
-
- # self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)
- # self.reverse_random_half_in_place(self.test_w_quizzes)
-
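+        # Guards train_c_quizzes / test_c_quizzes: training threads
+        # read them in batches() while the main thread appends to them
+        # in store_c_quizzes().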
+ self.LOCK_C_QUIZZES = threading.Lock()
self.train_c_quizzes = []
self.test_c_quizzes = []
- # if result_dir is not None:
- # self.save_quizzes(
- # result_dir,
- # "culture_w_quizzes",
- # self.train_w_quizzes[:72],
- # )
-
def save_quizzes(
self,
result_dir,
def batches(self, model, split="train", desc=None):
assert split in {"train", "test"}
- if split == "train":
- w_quizzes = model.train_w_quizzes
- c_quizzes = self.train_c_quizzes
- else:
- w_quizzes = model.test_w_quizzes
- c_quizzes = self.test_c_quizzes
- if len(c_quizzes) > 0:
- c_quizzes = torch.cat(c_quizzes, dim=0)
- if c_quizzes.size(0) > w_quizzes.size(0) // 2:
- i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
- c_quizzes = c_quizzes[i]
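+        # Snapshot and mix the quiz pools under the lock so a
+        # concurrent store_c_quizzes() cannot mutate the lists mid-read.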
+ with self.LOCK_C_QUIZZES:
+ if split == "train":
+ w_quizzes = model.train_w_quizzes
+ c_quizzes = self.train_c_quizzes
+ else:
+ w_quizzes = model.test_w_quizzes
+ c_quizzes = self.test_c_quizzes
+
+ if len(c_quizzes) > 0:
+ c_quizzes = torch.cat(c_quizzes, dim=0)
+ if c_quizzes.size(0) > w_quizzes.size(0) // 2:
+ i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
+ c_quizzes = c_quizzes[i]
+
+ i = torch.randperm(w_quizzes.size(0))[
+ : w_quizzes.size(0) - c_quizzes.size(0)
+ ]
+ w_quizzes = w_quizzes[i]
- i = torch.randperm(w_quizzes.size(0))[
- : w_quizzes.size(0) - c_quizzes.size(0)
- ]
- w_quizzes = w_quizzes[i]
+ self.nb_batch_w_quizzes = w_quizzes.size(0)
+ self.nb_batch_c_quizzes = c_quizzes.size(0)
- self.nb_batch_w_quizzes = w_quizzes.size(0)
- self.nb_batch_c_quizzes = c_quizzes.size(0)
-
- input = torch.cat([w_quizzes, c_quizzes], dim=0)
- else:
- input = w_quizzes
- self.nb_batch_w_quizzes = w_quizzes.size(0)
- self.nb_batch_c_quizzes = 0
+ input = torch.cat([w_quizzes, c_quizzes], dim=0)
+ else:
+ input = w_quizzes
+ self.nb_batch_w_quizzes = w_quizzes.size(0)
+ self.nb_batch_c_quizzes = 0
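+            # Either way the pool keeps the size of the original
+            # w_quizzes set: culture quizzes are capped at half of it,
+            # and world quizzes fill the remainder.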
# Shuffle
input = input[torch.randperm(input.size(0))]
######################################################################
def store_c_quizzes(self, new_c_quizzes, for_train=True):
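+        # Taking the same lock as batches() so that appends from the
+        # main thread cannot race with a training thread concatenating
+        # these lists.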
- if for_train:
- self.train_c_quizzes.append(new_c_quizzes)
- else:
- self.test_c_quizzes.append(new_c_quizzes)
+ with self.LOCK_C_QUIZZES:
+ if for_train:
+ self.train_c_quizzes.append(new_c_quizzes)
+ else:
+ self.test_c_quizzes.append(new_c_quizzes)
######################################################################