# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, sys, argparse, time, tqdm, os, datetime, warnings

import torch, torchvision

from torch.nn import functional as F

import mygpt  # project-local GPT implementation (provides BracketedSequence), used below
import tasks  # assumed project-local module defining the quiz task built below
######################################################################

if torch.cuda.is_available():
    device = torch.device("cuda")
    # TF32 matmuls trade a little precision for a large speed-up on recent GPUs
    torch.backends.cuda.matmul.allow_tf32 = True
else:
    device = torch.device("cpu")
######################################################################

parser = argparse.ArgumentParser(
    description="An implementation of GPT with cache.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--log_filename", type=str, default="train.log", help=" ")

parser.add_argument("--result_dir", type=str, default=None)

parser.add_argument("--seed", type=int, default=0)

parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)

########################################

parser.add_argument("--nb_epochs", type=int, default=10000)

parser.add_argument("--batch_size", type=int, default=None)

parser.add_argument("--physical_batch_size", type=int, default=None)

parser.add_argument("--nb_train_samples", type=int, default=None)

parser.add_argument("--nb_test_samples", type=int, default=None)

parser.add_argument("--learning_rate", type=float, default=1e-4)

########################################

parser.add_argument("--model", type=str, default=None)

parser.add_argument("--dim_model", type=int, default=None)

parser.add_argument("--dim_keys", type=int, default=None)

parser.add_argument("--dim_hidden", type=int, default=None)

parser.add_argument("--nb_heads", type=int, default=None)

parser.add_argument("--nb_blocks", type=int, default=None)

parser.add_argument("--dropout", type=float, default=0.1)

########################################

parser.add_argument("--deterministic_synthesis", action="store_true", default=False)

parser.add_argument("--nb_gpts", type=int, default=5)

parser.add_argument("--check", action="store_true", default=False)

######################################################################

args = parser.parse_args()
if args.result_dir is None:
    args.result_dir = "results_culture"

######################################################################

default_args = {  # defaults for the arguments left to None above
    "nb_train_samples": 250000,
    "nb_test_samples": 10000,
}

for k, v in default_args.items():
    if getattr(args, k) is None:
        setattr(args, k, v)

######################################################################
default_model_args = {
    # per-model presets for dim_model, dim_keys, dim_hidden, nb_heads and
    # nb_blocks (the preset table itself is not reproduced here)
}

if args.model in default_model_args:
    for k, v in default_model_args[args.model].items():
        if getattr(args, k) is None:
            setattr(args, k, v)
else:
    raise ValueError(f"Unknown model {args.model}")

######################################################################
try:
    os.mkdir(args.result_dir)
except FileExistsError:
    print(f"result directory {args.result_dir} already exists")

log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# torch.use_deterministic_algorithms(True)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)
######################################################################


def log_string(s):
    t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
    if log_file is not None:
        log_file.write(t + s + "\n")
    print(t + s)


log_string(f"argv {' '.join(sys.argv)}")

for n in vars(args):
    log_string(f"args.{n} {getattr(args, n)}")
######################################################################

if args.check:
    args.nb_train_samples = 500
    args.nb_test_samples = 100

if args.physical_batch_size is None:
    args.physical_batch_size = args.batch_size
else:
    assert args.batch_size % args.physical_batch_size == 0

assert args.nb_train_samples % args.batch_size == 0
assert args.nb_test_samples % args.batch_size == 0
# Assumption: the task object is built from the project-local tasks module; it
# provides the batches, the vocabulary, and the evaluation helpers used below.
task = tasks.World(
    nb_train_samples=args.nb_train_samples,
    nb_test_samples=args.nb_test_samples,
    batch_size=args.physical_batch_size,
    result_dir=args.result_dir,
)
######################################################################

log_string(f"device {device}")

vocabulary_size = task.vocabulary_size()

log_string(f"vocabulary_size {vocabulary_size}")

######################################################################
# Compute the entropy of the training tokens

token_count = 0
for input in task.batches(split="train", desc="train-entropy"):
    token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
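# Note that train_set_perplexity is exp of the unigram entropy, i.e. the
# perplexity a memoryless model would achieve on the training tokens.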
######################################################################
# A bit of paranoia never hurts
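# Check how many test sequences also appear verbatim in the train set: the
# sequences are compared as tuples, streamed in chunks of at most 25000, and
# the run aborts if the overlap exceeds --max_percents_of_test_in_train percent.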
if args.max_percents_of_test_in_train >= 0:

    def subsets_as_tuples(batches, cs):
        s = set()
        for batch in batches:
            for x in batch:
                s.add(tuple([v.item() for v in x]))
                if len(s) == cs:
                    yield s
                    s = set()
        yield s

    nb_test, nb_in_train = 0, 0
    for test_subset in subsets_as_tuples(
        task.batches(split="test", desc="test-check"), 25000
    ):
        in_train = set()
        for train_subset in subsets_as_tuples(
            task.batches(split="train", desc="train-check"), 25000
        ):
            in_train.update(test_subset.intersection(train_subset))
        nb_in_train += len(in_train)
        nb_test += len(test_subset)

    log_string(
        f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
    )

    assert (
        nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
    ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
##############
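# one_epoch performs a single pass over the training split: gradients are
# accumulated over physical batches and the optimizer steps once every
# args.batch_size samples, then the training perplexity is logged.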
def one_epoch(model, task):
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split="train"):
        input = input.to(device)

        if nb_train_samples % args.batch_size == 0:
            optimizer.zero_grad()

        output = model(mygpt.BracketedSequence(input)).x
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)

        nb_train_samples += input.size(0)

        loss.backward()

        if nb_train_samples % args.batch_size == 0:
            optimizer.step()

    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))

    log_string(f"train_perplexity {n_epoch} {train_perplexity}")
######################################################################
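# run_tests evaluates the model on the test split without gradient tracking,
# logs the test perplexity, and stores the task's main accuracy metric on the
# model itself so that the main loop can rank the models by it.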
def run_tests(model, task, deterministic_synthesis):
    with torch.autograd.no_grad():
        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0
        nb_samples_accumulated = 0

        for input in task.batches(split="test"):
            input = input.to(device)

            bs = model(mygpt.BracketedSequence(input))
            output = bs.x
            loss = F.cross_entropy(output.transpose(1, 2), input)

            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        main_test_accuracy = task.produce_results(
            result_dir=args.result_dir,
            deterministic_synthesis=deterministic_synthesis,
        )

        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))

        log_string(f"test_perplexity {n_epoch} {test_perplexity}")

        model.main_test_accuracy = main_test_accuracy
######################################################################
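# Quiz creation: candidate quizzes are produced in batches of
# 4 * (nb_for_train + nb_for_test), and only those for which the number of
# correct answers among the other models equals len(other_models) - 1 are
# kept, until enough are accumulated to extend the train and test sets.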
# the function name and signature below are assumed from the call site
def create_quizzes(model, other_models, task, nb_for_train, nb_for_test):
    kept = []

    while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
        new_quizzes, nb_correct = task.create_new_quizzes(
            result_dir=args.result_dir,
            nb=4 * (nb_for_train + nb_for_test),
            other_models=other_models,
        )

        to_keep = new_quizzes[nb_correct == len(other_models) - 1]
        log_string(f"keep {to_keep.size(0)} quizzes")
        kept.append(to_keep)

    new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]

    task.store_new_quizzes(new_quizzes[:nb_for_train], for_train=True)
    task.store_new_quizzes(new_quizzes[nb_for_train:], for_train=False)

    # an image of the kept quizzes is saved for inspection; the helper call is
    # assumed, only the file name pattern comes from the original
    task.save_image(new_quizzes, args.result_dir, f"world_quiz_{n_epoch:04d}_{model.id:02d}.png")
######################################################################
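# args.nb_gpts independent GPT models are instantiated with identical
# hyper-parameters; each one carries its own main_test_accuracy.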
models = []

for k in range(args.nb_gpts):
    # the constructor name is assumed; the model class comes from mygpt
    model = mygpt.MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=args.dim_model,
        dim_keys=args.dim_keys,
        dim_hidden=args.dim_hidden,
        nb_heads=args.nb_heads,
        nb_blocks=args.nb_blocks,
        dropout=args.dropout,
    ).to(device)

    model.main_test_accuracy = 0.0
    model.id = k

    models.append(model)

nb_parameters = sum(p.numel() for p in models[0].parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
######################################################################
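# A model is allowed to generate new quizzes only once its main test accuracy
# reaches accuracy_to_make_quizzes; --check shrinks everything for a dry run.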
accuracy_to_make_quizzes = 0.975
nb_new_quizzes_for_train = 1000
nb_new_quizzes_for_test = 100

if args.check:
    accuracy_to_make_quizzes = 0.0
    nb_new_quizzes_for_train = 10
    nb_new_quizzes_for_test = 10
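# Main loop: at every epoch the model with the lowest main test accuracy is
# trained for one epoch and re-tested; once it passes the quiz-making
# threshold it generates new quizzes, which are added to the train and test
# sets, and it is evaluated again on the extended test set.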
for n_epoch in range(args.nb_epochs):
    a = [(model.id, model.main_test_accuracy) for model in models]
    a.sort(key=lambda p: p[0])
    log_string(f"current accuracies {a}")

    # select the model with lowest accuracy
    models.sort(key=lambda model: model.main_test_accuracy)
    model = models[0]

    log_string(
        f"training model {model.id} main_test_accuracy {model.main_test_accuracy}"
    )

    one_epoch(model, task)

    log_string(
        f"train_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}"
    )

    run_tests(model, task, deterministic_synthesis=False)

    if model.main_test_accuracy >= accuracy_to_make_quizzes:
        other_models = models.copy()
        other_models.remove(model)

        create_quizzes(
            model, other_models, task,
            nb_for_train=nb_new_quizzes_for_train,
            nb_for_test=nb_new_quizzes_for_test,
        )

        run_tests(model, task, deterministic_synthesis=False)
######################################################################
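# Example invocation (the script file name is assumed):
#
#   python main.py --check --nb_gpts 2 --result_dir /tmp/culture_check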