3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, argparse, time, tqdm, os, datetime, warnings
10 import torch, torchvision
12 from torch.nn import functional as F
15 import mygpt, tasks, problems
17 ######################################################################
# Select the compute device once for the whole script; TF32 matmuls are
# enabled on CUDA for speed.
# NOTE(review): the `else:` line was lost in this excerpt and has been
# restored — the CPU fallback at the end clearly belongs to it.
if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.backends.cuda.matmul.allow_tf32 = True
else:
    device = torch.device("cpu")
25 ######################################################################
# Command-line interface. ArgumentDefaultsHelpFormatter makes --help show
# every default value.
# NOTE(review): the closing `)` of the ArgumentParser call was lost in this
# excerpt and has been restored.
parser = argparse.ArgumentParser(
    description="An implementation of GPT with cache.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument("--task", type=str, default="world", help="world")

parser.add_argument("--log_filename", type=str, default="train.log", help=" ")

parser.add_argument("--result_dir", type=str, default=None)

parser.add_argument("--seed", type=int, default=0)

parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)

########################################

parser.add_argument("--nb_epochs", type=int, default=10000)

parser.add_argument("--batch_size", type=int, default=None)

parser.add_argument("--physical_batch_size", type=int, default=None)

parser.add_argument("--nb_train_samples", type=int, default=None)

parser.add_argument("--nb_test_samples", type=int, default=None)

parser.add_argument("--learning_rate", type=float, default=1e-4)

########################################

# Model hyper-parameters default to None so per-task/per-model presets can
# fill them in later without clobbering explicit user choices.
parser.add_argument("--model", type=str, default=None)

parser.add_argument("--dim_model", type=int, default=None)

parser.add_argument("--dim_keys", type=int, default=None)

parser.add_argument("--dim_hidden", type=int, default=None)

parser.add_argument("--nb_heads", type=int, default=None)

parser.add_argument("--nb_blocks", type=int, default=None)

parser.add_argument("--dropout", type=float, default=0.1)

########################################

parser.add_argument("--deterministic_synthesis", action="store_true", default=False)

parser.add_argument("--check", action="store_true", default=False)
78 ######################################################################
# Parse the CLI; when no result directory is given, derive one from the task.
args = parser.parse_args()

if args.result_dir is None:
    args.result_dir = f"results_{args.task}"
85 ######################################################################
91 "nb_train_samples": 250000,
92 "nb_test_samples": 10000,
96 if args.task in default_task_args:
97 for k, v in default_task_args[args.task].items():
98 if getattr(args, k) is None:
101 ######################################################################
103 default_model_args = {
# Fill unspecified model hyper-parameters from the chosen preset; an unknown
# --model is a fatal configuration error.
# NOTE(review): the `setattr` line and the `else:` were lost in this excerpt
# and have been restored — the `raise` clearly belongs to the else branch.
if args.model in default_model_args:
    for k, v in default_model_args[args.model].items():
        if getattr(args, k) is None:
            setattr(args, k, v)
else:
    raise ValueError(f"Unknown model {args.model}")
148 ######################################################################
# Create the result directory; an existing one is reused, not an error.
# NOTE(review): the `try:` line was lost in this excerpt and has been
# restored — it is implied by the `except FileExistsError`.
try:
    os.mkdir(args.result_dir)
except FileExistsError:
    print(f"result directory {args.result_dir} already exists")

# Open in append mode so successive runs accumulate in the same log file.
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
# Seed all RNGs so runs are repeatable. The stricter determinism switches
# below are deliberately left disabled (they trade away speed).
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# torch.use_deterministic_algorithms(True)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)
166 ######################################################################
170 t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
172 if log_file is not None:
173 log_file.write(t + s + "\n")
# Record the exact command line and every parsed argument for reproducibility.
log_string(f"argv {' '.join(sys.argv)}")

# NOTE(review): the loop head was lost in this excerpt; `for n in vars(args):`
# is implied by the `args.{n}` / `getattr(args, n)` body — confirm upstream.
for n in vars(args):
    log_string(f"args.{n} {getattr(args, n)}")
186 ######################################################################
# With --check, shrink the run so one full pass finishes quickly.
# NOTE(review): the `if args.check:` guard and the `else:` before the first
# assert were lost in this excerpt and have been restored — confirm upstream.
if args.check:
    args.nb_train_samples = 1000
    args.nb_test_samples = 25

if args.physical_batch_size is None:
    args.physical_batch_size = args.batch_size
else:
    # A logical batch must be an exact multiple of the physical one, since
    # gradients are accumulated over physical sub-batches.
    assert args.batch_size % args.physical_batch_size == 0

assert args.nb_train_samples % args.batch_size == 0
assert args.nb_test_samples % args.batch_size == 0
200 if args.task == "file":
202 args.filetask_train_file is not None and args.filetask_test_file is not None
203 ), "You have to specify the task train and test files"
204 task = tasks.TaskFromFile(
205 args.filetask_train_file,
206 args.filetask_test_file,
207 nb_train_samples=args.nb_train_samples,
208 nb_test_samples=args.nb_test_samples,
209 batch_size=args.physical_batch_size,
213 args.max_percents_of_test_in_train = 0
215 elif args.task == "byheart":
216 task = tasks.SandBox(
217 problem=problems.ProblemByHeart(separation=args.byheart_separation),
218 nb_train_samples=args.nb_train_samples,
219 nb_test_samples=args.nb_test_samples,
220 batch_size=args.physical_batch_size,
224 args.max_percents_of_test_in_train = -1
226 elif args.task == "world":
228 nb_train_samples=args.nb_train_samples,
229 nb_test_samples=args.nb_test_samples,
230 batch_size=args.physical_batch_size,
231 result_dir=args.result_dir,
235 args.max_percents_of_test_in_train = -1
237 elif args.task == "learnop":
238 task = tasks.SandBox(
239 problem=problems.ProblemLearnOperator(),
240 nb_train_samples=args.nb_train_samples,
241 nb_test_samples=args.nb_test_samples,
242 batch_size=args.physical_batch_size,
248 elif args.task == "guessop":
249 task = tasks.SandBox(
250 problem=problems.ProblemGuessOperator(),
251 nb_train_samples=args.nb_train_samples,
252 nb_test_samples=args.nb_test_samples,
253 batch_size=args.physical_batch_size,
259 elif args.task == "twotargets":
260 task = tasks.SandBox(
261 problem=problems.ProblemTwoTargets(),
262 nb_train_samples=args.nb_train_samples,
263 nb_test_samples=args.nb_test_samples,
264 batch_size=args.physical_batch_size,
269 elif args.task == "memory":
270 task = tasks.SandBox(
271 problem=problems.ProblemMemory(),
272 nb_train_samples=args.nb_train_samples,
273 nb_test_samples=args.nb_test_samples,
274 batch_size=args.physical_batch_size,
279 elif args.task == "mixing":
280 task = tasks.SandBox(
281 problem=problems.ProblemMixing(
282 hard=args.mixing_hard, random_start=not args.mixing_deterministic_start
284 nb_train_samples=args.nb_train_samples,
285 nb_test_samples=args.nb_test_samples,
286 batch_size=args.physical_batch_size,
291 elif args.task == "addition":
292 task = tasks.SandBox(
293 problem=problems.ProblemAddition(),
294 nb_train_samples=args.nb_train_samples,
295 nb_test_samples=args.nb_test_samples,
296 batch_size=args.physical_batch_size,
301 elif args.task == "picoclvr":
302 task = tasks.PicoCLVR(
303 nb_train_samples=args.nb_train_samples,
304 nb_test_samples=args.nb_test_samples,
305 batch_size=args.physical_batch_size,
306 height=args.picoclvr_height,
307 width=args.picoclvr_width,
308 nb_colors=args.picoclvr_nb_colors,
311 pruner_train=picoclvr_pruner_train,
312 pruner_eval=picoclvr_pruner_eval,
315 elif args.task == "mnist":
317 nb_train_samples=args.nb_train_samples,
318 nb_test_samples=args.nb_test_samples,
319 batch_size=args.physical_batch_size,
323 elif args.task == "maze":
325 nb_train_samples=args.nb_train_samples,
326 nb_test_samples=args.nb_test_samples,
327 batch_size=args.physical_batch_size,
328 height=args.maze_height,
329 width=args.maze_width,
330 nb_walls=args.maze_nb_walls,
334 elif args.task == "snake":
336 nb_train_samples=args.nb_train_samples,
337 nb_test_samples=args.nb_test_samples,
338 batch_size=args.physical_batch_size,
339 height=args.snake_height,
340 width=args.snake_width,
341 nb_colors=args.snake_nb_colors,
342 length=args.snake_length,
343 prompt_length=args.snake_length // 2,
347 elif args.task == "stack":
349 nb_train_samples=args.nb_train_samples,
350 nb_test_samples=args.nb_test_samples,
351 batch_size=args.physical_batch_size,
353 nb_steps=args.stack_nb_steps,
354 nb_stacks=args.stack_nb_stacks,
355 nb_digits=args.stack_nb_digits,
356 fraction_values_for_train=args.stack_fraction_values_for_train,
360 elif args.task == "expr":
362 nb_train_samples=args.nb_train_samples,
363 nb_test_samples=args.nb_test_samples,
364 nb_variables=args.expr_nb_variables,
365 sequence_length=args.expr_sequence_length,
366 operand_max=args.expr_operand_max,
367 result_max=args.expr_result_max,
368 batch_size=args.physical_batch_size,
372 elif args.task == "rpl":
374 nb_train_samples=args.nb_train_samples,
375 nb_test_samples=args.nb_test_samples,
376 batch_size=args.physical_batch_size,
377 nb_starting_values=args.rpl_nb_starting_values,
378 max_input=args.rpl_max_input,
379 prog_len=args.rpl_prog_len,
380 nb_runs=args.rpl_nb_runs,
381 no_prog=args.rpl_no_prog,
386 elif args.task == "grid":
388 nb_train_samples=args.nb_train_samples,
389 nb_test_samples=args.nb_test_samples,
390 batch_size=args.physical_batch_size,
392 fraction_play=args.grid_fraction_play,
397 elif args.task == "qmlp":
399 nb_train_samples=args.nb_train_samples,
400 nb_test_samples=args.nb_test_samples,
401 batch_size=args.physical_batch_size,
402 result_dir=args.result_dir,
407 elif args.task == "greed":
409 nb_train_samples=args.nb_train_samples,
410 nb_test_samples=args.nb_test_samples,
411 batch_size=args.physical_batch_size,
412 height=args.greed_height,
413 width=args.greed_width,
415 nb_walls=args.greed_nb_walls,
416 nb_coins=args.greed_nb_coins,
422 raise ValueError(f"Unknown task {args.task}")
424 ######################################################################
log_string(f"device {device}")

# The task defines the token alphabet; the model's embedding size follows it.
vocabulary_size = task.vocabulary_size()

log_string(f"vocabulary_size {vocabulary_size}")
432 ######################################################################
# Compute the entropy of the training tokens, giving the perplexity a
# lower-bound reference for this data distribution.
# NOTE(review): the `token_count = 0` initializer was lost in this excerpt
# and has been restored — it is implied by the `+=` on first use.
token_count = 0
for input in task.batches(split="train", desc="train-entropy"):
    token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
# xlogy handles p == 0 terms (0 * log 0 == 0) without producing NaNs.
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
443 ######################################################################
# A bit of paranoia never hurts: measure how much of the test set leaks into
# the train set, and abort if it exceeds the allowed percentage.
if args.max_percents_of_test_in_train >= 0:

    # NOTE(review): parts of this generator were lost in this excerpt; the
    # chunked set-building logic has been reconstructed — confirm upstream.
    def subsets_as_tuples(batches, cs):
        """Yield the rows of *batches* as sets of tuples, at most *cs* per set."""
        s = set()
        for batch in batches:
            for x in batch:
                s.add(tuple([v.item() for v in x]))
                if len(s) == cs:
                    yield s
                    s = set()
        yield s

    nb_test, nb_in_train = 0, 0
    for test_subset in subsets_as_tuples(
        task.batches(split="test", desc="test-check"), 25000
    ):
        in_train = set()
        for train_subset in subsets_as_tuples(
            task.batches(split="train", desc="train-check"), 25000
        ):
            in_train.update(test_subset.intersection(train_subset))
        nb_in_train += len(in_train)
        nb_test += len(test_subset)

    log_string(
        f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
    )

    assert (
        nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
    ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
478 ##############################
def one_epoch(model, task):
    """Train *model* for one pass over the task's train split.

    Gradients are accumulated over physical batches and applied once per
    logical batch of args.batch_size samples; logs the resulting train
    perplexity. Reads the globals `args`, `device` and `n_epoch`.

    NOTE(review): several lines were lost in this excerpt; `model.train()`,
    `loss.backward()` and `optimizer.step()` have been restored at the
    conventional positions — without them the visible code never updated any
    weights. Confirm upstream.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split="train"):
        input = input.to(device)

        # Start of a new logical batch: clear accumulated gradients.
        if nb_train_samples % args.batch_size == 0:
            optimizer.zero_grad()

        output = model(mygpt.BracketedSequence(input)).x
        # cross_entropy expects (N, C, ...) logits, hence the transpose.
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)

        nb_train_samples += input.size(0)

        loss.backward()

        # End of a logical batch: apply the accumulated gradients.
        if nb_train_samples % args.batch_size == 0:
            optimizer.step()

    # Clamp the exponent so a diverging loss cannot overflow math.exp.
    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))

    log_string(f"train_perplexity {n_epoch} {train_perplexity}")
510 ######################################################################
def run_tests(model, task, deterministic_synthesis):
    """Evaluate *model* on the test split and produce task result artifacts.

    Computes the test perplexity, lets the task generate its results, and
    stores the returned accuracy on the model as `model.main_test_accuracy`.
    Reads the globals `args`, `device` and `n_epoch`.

    NOTE(review): several lines were lost in this excerpt; `model.eval()`,
    `output = bs.x` and the keyword list of task.produce_results() have been
    reconstructed — confirm upstream.
    """
    with torch.autograd.no_grad():
        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0
        # NOTE(review): appears unused in the visible lines; kept as-is.
        nb_samples_accumulated = 0

        for input in task.batches(split="test"):
            input = input.to(device)

            bs = model(mygpt.BracketedSequence(input))
            output = bs.x

            loss = F.cross_entropy(output.transpose(1, 2), input)

            acc_test_loss += loss.item() * input.size(0)

            nb_test_samples += input.size(0)

        main_test_accuracy = task.produce_results(
            n_epoch=n_epoch,
            model=model,
            result_dir=args.result_dir,
            logger=log_string,
            deterministic_synthesis=deterministic_synthesis,
        )

        # Clamp the exponent so a diverging loss cannot overflow math.exp.
        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))

        log_string(f"test_perplexity {n_epoch} {test_perplexity}")

        model.main_test_accuracy = main_test_accuracy
547 ######################################################################
559 while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
560 new_quizzes, nb_correct = task.create_new_quizzes(
562 result_dir=args.result_dir,
564 nb=4 * (nb_for_train + nb_for_test),
566 other_models=other_models,
569 to_keep = new_quizzes[nb_correct == len(other_models) - 1]
570 log_string(f"keep {to_keep.size(0)} quizzes")
573 new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
575 task.store_new_quizzes(new_quizzes[:nb_for_train], for_train=True)
576 task.store_new_quizzes(new_quizzes[nb_for_train:], for_train=False)
581 f"world_new_{n_epoch:04d}_{model.id:02d}.png",
586 ######################################################################
592 vocabulary_size=vocabulary_size,
593 dim_model=args.dim_model,
594 dim_keys=args.dim_keys,
595 dim_hidden=args.dim_hidden,
596 nb_heads=args.nb_heads,
597 nb_blocks=args.nb_blocks,
599 dropout=args.dropout,
602 model.main_test_accuracy = 0.0
# All models share one architecture, so counting the first one suffices.
nb_parameters = sum(p.numel() for p in models[0].parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
611 ######################################################################
# A model must reach this test accuracy before it is trusted to generate new
# quizzes; the counts below size each generation round.
accuracy_to_make_quizzes = 0.975
nb_new_quizzes_for_train = 1000
nb_new_quizzes_for_test = 100

# NOTE(review): the guard line was lost in this excerpt; `if args.check:`
# matches the check-mode overrides earlier in the script — confirm upstream.
if args.check:
    accuracy_to_make_quizzes = 0.0
    nb_new_quizzes_for_train = 10
    nb_new_quizzes_for_test = 10
for n_epoch in range(args.nb_epochs):
    # select the model with lowest accuracy and train that one this epoch
    models.sort(key=lambda model: model.main_test_accuracy)
    model = models[0]

    log_string(
        f"training model {model.id} main_test_accuracy {model.main_test_accuracy}"
    )

    one_epoch(model, task)

    log_string(
        f"train_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}"
    )

    run_tests(model, task, deterministic_synthesis=False)

    # Once good enough, let this model generate new quizzes, validated by the
    # other models, and fold them into the data set.
    if model.main_test_accuracy >= accuracy_to_make_quizzes:
        other_models = models.copy()
        other_models.remove(model)

        # NOTE(review): the callee's name was lost in this excerpt; only the
        # keyword arguments below are original — confirm the quiz-creation
        # function's name and positional arguments upstream.
        create_quizzes(
            model,
            other_models,
            task,
            nb_for_train=nb_new_quizzes_for_train,
            nb_for_test=nb_new_quizzes_for_test,
        )
654 ######################################################################