3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, argparse, time, tqdm, os, datetime, warnings
10 import torch, torchvision
12 from torch.nn import functional as F
17 import sky, grids, quiz_machine
21 import torch.multiprocessing as mp
23 ######################################################################
# Command-line interface.
# NOTE(review): this extract omits some original lines (e.g. the rest of the
# ArgumentParser(...) call and the header of the grid-task option), so these
# comments describe only the options that are visible.
25 parser = argparse.ArgumentParser(
26 formatter_class=argparse.ArgumentDefaultsHelpFormatter,
# Run bookkeeping: log file name, output directory (set to "results_culture"
# further down when left unset), RNG seed, and resuming a previous run.
29 parser.add_argument("--log_filename", type=str, default="train.log")
31 parser.add_argument("--result_dir", type=str, default=None)
33 parser.add_argument("--seed", type=int, default=0)
35 parser.add_argument("--resume", action="store_true", default=False)
# Maximum percentage of test sequences allowed to also appear verbatim in the
# train split; a negative value (the default) skips that check entirely.
37 parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1)
39 ########################################
# Training. Options left at None are filled in later: some from default_args,
# physical_batch_size from batch_size, and the nb_new_c_quizzes_* counts from
# nb_*_samples // 50.
41 parser.add_argument("--nb_epochs", type=int, default=10000)
43 parser.add_argument("--batch_size", type=int, default=None)
45 parser.add_argument("--physical_batch_size", type=int, default=None)
47 parser.add_argument("--nb_train_samples", type=int, default=None)
49 parser.add_argument("--nb_test_samples", type=int, default=None)
51 parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
53 parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
55 parser.add_argument("--learning_rate", type=float, default=5e-4)
57 ########################################
# Model hyper-parameters. --model names a preset in default_model_args; every
# option below left at None takes the preset's value, explicit values win.
59 parser.add_argument("--model", type=str, default=None)
61 parser.add_argument("--dim_model", type=int, default=None)
63 parser.add_argument("--dim_keys", type=int, default=None)
65 parser.add_argument("--dim_hidden", type=int, default=None)
67 parser.add_argument("--nb_heads", type=int, default=None)
69 parser.add_argument("--nb_blocks", type=int, default=None)
71 parser.add_argument("--dropout", type=float, default=0.1)
73 ########################################
# Experiment control.
# --gpus: "all" or a comma-separated list of CUDA device indices.
# --nb_gpts: number of models trained side by side.
# --accuracy_to_make_c_quizzes: every model must reach this test accuracy
#   before new quizzes are generated.
# --proba_understands / --proba_not_understands: thresholds used by
#   compute_valid_quizzes on the models' solution probabilities.
# --generation_temperature: forwarded to quiz_machine.generate_quizzes.
75 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
77 parser.add_argument("--problem", type=str, default="grids")
79 parser.add_argument("--nb_threads", type=int, default=1)
81 parser.add_argument("--gpus", type=str, default="all")
83 parser.add_argument("--nb_gpts", type=int, default=5)
85 parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9)
87 parser.add_argument("--proba_understands", type=float, default=0.9)
89 parser.add_argument("--proba_not_understands", type=float, default=0.5)
91 parser.add_argument("--generation_temperature", type=float, default=1.0)
93 parser.add_argument("--dirty_debug", action="store_true", default=False)
95 ######################################################################
# Comma-separated names of the available grid tasks, without their "task_"
# prefix, used in the help text of the grid-task option (str.removeprefix
# requires Python >= 3.9).
97 grids_tasks = ", ".join(
98 [x.__name__.removeprefix("task_") for x in grids.Grids().all_tasks]
105 help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
108 ######################################################################
# Parameters of the "sky" problem, forwarded to its constructor.
110 parser.add_argument("--sky_height", type=int, default=6)
112 parser.add_argument("--sky_width", type=int, default=8)
114 parser.add_argument("--sky_nb_birds", type=int, default=3)
116 parser.add_argument("--sky_nb_iterations", type=int, default=2)
118 parser.add_argument("--sky_speed", type=int, default=3)
120 ######################################################################
122 args = parser.parse_args()
if args.result_dir is None:
    # No output directory given on the command line: use a fixed default.
    # (Plain string: the original f-string had no placeholders.)
    args.result_dir = "results_culture"
127 ######################################################################
# Training defaults. Every option listed in default_args (only part of the
# dict is visible here) that is still None after parsing gets its default
# value (presumably via setattr on a line not visible here; confirm).
132 "nb_train_samples": 100000,
133 "nb_test_samples": 10000,
136 for k, v in default_args.items():
137 if getattr(args, k) is None:
140 ######################################################################
# Model presets, keyed by the name given with --model (dict body not visible).
142 default_model_args = {
# Options still None take the preset's value; explicitly passed values win.
# An unknown preset name is rejected with ValueError.
# NOTE(review): --model itself defaults to None; presumably default_args
# supplies a preset name, otherwise this raises "Unknown model None".
180 if args.model in default_model_args:
181 for k, v in default_model_args[args.model].items():
182 if getattr(args, k) is None:
185 raise ValueError(f"Unknown model {args.model}")
187 ######################################################################
# Output directory. Presumably the assert is on the resume path (its `if`
# line is not visible): resuming requires an existing directory. Otherwise
# the directory is created, and an existing one is reported.
# NOTE(review): `assert` is stripped under `python -O`; the lines following
# the print (original 197-198) are not visible, so confirm whether the run
# stops or continues when the directory already exists.
190 assert os.path.isdir(args.result_dir)
194 os.mkdir(args.result_dir)
195 except FileExistsError:
196 print(f"result directory {args.result_dir} already exists")
# Append mode: successive runs accumulate in the same log file. The handle
# is never explicitly closed in the visible code.
199 log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
# Seed the CPU RNG and, when present, every CUDA device's RNG. The stricter
# determinism switches below are deliberately left commented out.
202 # torch.backends.cudnn.deterministic = True
203 # torch.backends.cudnn.benchmark = False
204 # torch.use_deterministic_algorithms(True)
205 torch.manual_seed(args.seed)
206 if torch.cuda.is_available():
207 torch.cuda.manual_seed_all(args.seed)
209 ######################################################################
# Fragment of the logging helper (presumably `def log_string(s)`; its header
# and tail are not visible in this extract). Each message is prefixed with a
# local timestamp and appended to log_file when one is open.
213 t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
215 if log_file is not None:
216 log_file.write(t + s + "\n")
# Snapshot the source files next to the results, so a run can be traced back
# to the exact code that produced it. The command runs without a shell: the
# archive path is built with os.path.join and the `*.py` pattern is expanded
# in Python. A result directory containing spaces or shell metacharacters
# can therefore neither break nor hijack the command. Like the original
# os.system call, this is best-effort: tar reports its own errors, and a
# missing tar binary is only logged.
import glob, subprocess

now = time.strftime("%Y%m%d-%H%M%S", time.localtime())

try:
    subprocess.run(
        ["tar", "zcvf", os.path.join(args.result_dir, f"src-{now}.tgz")]
        + sorted(glob.glob("*.py")),
        check=False,
    )
except OSError as e:
    log_string(f"cannot archive the source files: {e}")
# Record the exact command line, then every parsed option with its final
# value. The loop header producing `n` (original line 229) is not visible
# here; presumably it iterates over the option names of `args`.
227 log_string(f"argv {' '.join(sys.argv)}")
230 log_string(f"args.{n} {getattr(args, n)}")
######################################################################

# --gpus is either "all" (every visible CUDA device) or a comma-separated
# list of CUDA device indices, e.g. "0,2".
if args.gpus == "all":
    gpus_idx = range(torch.cuda.device_count())
else:
    gpus_idx = [int(k) for k in args.gpus.split(",")]

gpus = [torch.device(f"cuda:{n}") for n in gpus_idx]

if torch.cuda.is_available():
    main_device = gpus[0]
else:
    # Explicit GPU indices cannot be honoured without CUDA. This is an
    # explicit check rather than an assert, so it survives `python -O`.
    if len(gpus) > 0:
        raise ValueError(f"--gpus {args.gpus} requested but CUDA is not available")
    main_device = torch.device("cpu")
    # The training loop trains one model per entry of `gpus`, and the
    # problem's cache size scales with len(gpus). With an empty list, a
    # CPU-only run would never train any model, so the CPU becomes the
    # single worker device.
    gpus = [main_device]
# Presumably under `if args.dirty_debug:` (original line 248, not visible):
# shrink the datasets for a quick debugging run.
249 args.nb_train_samples = 2500
250 args.nb_test_samples = 100
# physical_batch_size is what one forward pass processes (it is the
# batch_size given to QuizMachine). batch_size is the number of samples
# per optimizer step in one_epoch, so it must be a multiple of the physical
# size, and both sample counts must be multiples of batch_size.
# NOTE(review): asserts are stripped under `python -O`.
252 if args.physical_batch_size is None:
253 args.physical_batch_size = args.batch_size
255 assert args.batch_size % args.physical_batch_size == 0
257 assert args.nb_train_samples % args.batch_size == 0
258 assert args.nb_test_samples % args.batch_size == 0
# Build the problem that generates world quizzes, selected by --problem
# (several constructor lines, and any branch for an unknown problem name,
# are not visible in this extract).
# NOTE(review): max_nb_cached_chunks scales with len(gpus); confirm this is
# the intended cache size when few or no GPUs are selected.
260 if args.problem == "sky":
262 height=args.sky_height,
263 width=args.sky_width,
264 nb_birds=args.sky_nb_birds,
265 nb_iterations=args.sky_nb_iterations,
266 speed=args.sky_speed,
267 max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
269 nb_threads=args.nb_threads,
271 back_accuracy = False
272 elif args.problem == "grids":
273 problem = grids.Grids(
274 max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
276 nb_threads=args.nb_threads,
277 tasks=args.grids_tasks,
283 problem.save_some_examples(args.result_dir)
# NOTE(review): this rebinds the name `quiz_machine` from the imported module
# to the QuizMachine instance; the module is unreachable under that name from
# here on, and every later `quiz_machine` refers to the instance.
285 quiz_machine = quiz_machine.QuizMachine(
287 nb_train_samples=args.nb_train_samples,
288 nb_test_samples=args.nb_test_samples,
289 back_accuracy=back_accuracy,
290 batch_size=args.physical_batch_size,
291 result_dir=args.result_dir,
######################################################################

# Report which devices are used and the number of distinct tokens the quiz
# machine works with (vocabulary_size is needed to build the models).
log_string(f"main_device {main_device} gpus {list(map(str, gpus))}")

vocabulary_size = quiz_machine.vocabulary_size()
log_string(f"vocabulary_size {vocabulary_size}")
304 ######################################################################
# Evaluate `model` on the test split, without gradients: log its test
# perplexity and store the value returned by quiz_machine.produce_results in
# model.main_test_accuracy.
#
# model: a model exposing `.id` and accepting a mygpt.BracketedSequence.
# quiz_machine: provides the test batches and produce_results.
# deterministic_synthesis: forwarded to produce_results.
# local_device: device to evaluate on; the default is main_device, bound
#   when the function is defined.
#
# Reads the module-level n_epoch (the training-loop variable) for logging.
# NOTE(review): raises ZeroDivisionError if the test split yields no batch.
307 def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_device):
308 with torch.autograd.no_grad():
309 model.eval().to(local_device)
311 nb_test_samples, acc_test_loss = 0, 0.0
# NOTE(review): nb_samples_accumulated is never used in the visible code.
312 nb_samples_accumulated = 0
314 for input in quiz_machine.batches(model, split="test"):
315 input = input.to(local_device)
317 bs = model(mygpt.BracketedSequence(input))
# NOTE(review): `output` is assigned on a line not visible here
# (presumably taken from `bs`).
320 loss = F.cross_entropy(output.transpose(1, 2), input)
# Weight each batch's mean loss by its size, so the division below gives a
# per-sample mean even when the last batch is smaller.
322 acc_test_loss += loss.item() * input.size(0)
324 nb_test_samples += input.size(0)
# The exponent is capped at 100 so math.exp cannot overflow.
326 test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
328 log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
# Some keyword arguments of this call (original 331-332) are not visible.
330 model.main_test_accuracy = quiz_machine.produce_results(
333 result_dir=args.result_dir,
334 deterministic_synthesis=deterministic_synthesis,
# Train `model` for one pass over its training split on local_device, log the
# train perplexity, evaluate it with run_tests, then move it back to
# main_device.
#
# A fresh Adam optimizer is built on every call, so its moment estimates do
# not carry over from one epoch to the next.
# Reads the module-level n_epoch for logging. Evaluation always uses
# deterministic_synthesis=False, whatever --deterministic_synthesis says.
338 def one_epoch(model, quiz_machine, local_device=main_device):
339 model.to(local_device).train()
341 optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
343 nb_train_samples, acc_train_loss = 0, 0.0
345 for input in quiz_machine.batches(model, split="train"):
346 input = input.to(local_device)
# Gradients are reset at the start of each logical batch of args.batch_size
# samples, which spans several physical batches.
348 if nb_train_samples % args.batch_size == 0:
349 optimizer.zero_grad()
351 output = model(mygpt.BracketedSequence(input)).x
352 loss = F.cross_entropy(output.transpose(1, 2), input)
353 acc_train_loss += loss.item() * input.size(0)
355 nb_train_samples += input.size(0)
# NOTE(review): the backward pass and the body of the `if` below
# (presumably the optimizer step) are on lines not visible here. If the
# per-physical-batch mean losses are back-propagated unscaled, the
# accumulated gradient grows with batch_size / physical_batch_size. Confirm
# whether `loss * input.size(0) / args.batch_size` was intended.
359 if nb_train_samples % args.batch_size == 0:
362 train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
364 log_string(f"train_perplexity {n_epoch} model {model.id} {train_perplexity}")
366 run_tests(model, quiz_machine, deterministic_synthesis=False)
368 model.to(main_device)
######################################################################

# The routines below decide which generated quizzes to keep.
#
# token_logprobas is an N x M x T tensor: N quizzes, M models, and T
# per-token log-probabilities of each quiz's solution.


def compute_valid_quizzes_(
    token_logprobas, *, proba_not_understands=0.1, proba_understands=0.5
):
    """Validate quizzes with a per-token ("uniform") criterion.

    For every quiz and every model, the worst (smallest) token
    log-probability is taken. The models are then sorted from the least to
    the most confident. A quiz is kept when

    * the least confident model's worst token has probability below
      ``proba_not_understands``, and
    * the second least confident model's worst token (hence every other
      model's) has probability above ``proba_understands``,

    i.e. exactly one model fails the quiz while all the others solve it.

    Args:
        token_logprobas: tensor of shape (N, M, T). M must be at least 2,
            otherwise indexing the second model raises IndexError.
        proba_not_understands: threshold below which the weakest model is
            deemed not to understand (default 0.1).
        proba_understands: threshold above which the other models are
            deemed to understand (default 0.5).

    Returns:
        Boolean tensor of shape (N,), True for the quizzes to keep.

    Calls warnings.warn(..., RuntimeWarning) to flag that this variant is
    in use.
    """
    warnings.warn("validation with uniform constraints", RuntimeWarning)
    worst_token = token_logprobas.min(dim=-1).values  # (N, M)
    ranked = worst_token.sort(dim=-1).values  # least confident model first
    return (ranked[:, 0] < math.log(proba_not_understands)) & (
        ranked[:, 1] > math.log(proba_understands)
    )
def compute_valid_quizzes(
    token_logprobas, *, proba_not_understands=None, proba_understands=None
):
    """Decide which generated quizzes to keep.

    The per-token log-probabilities are summed, which gives, for every quiz
    and model, the log-probability of the whole solution. The models are
    then sorted from the least to the most confident. A quiz is kept when
    the least confident model's solution probability is below
    ``proba_not_understands`` while the second least confident model's
    (hence every other model's) is above ``proba_understands``. In other
    words, exactly one model fails the quiz and all the others solve it.

    Args:
        token_logprobas: tensor of shape (N, M, T). M must be at least 2.
        proba_not_understands: threshold for the failing model. None (the
            default) uses args.proba_not_understands.
        proba_understands: threshold for the other models. None (the
            default) uses args.proba_understands.

    Returns:
        Boolean tensor of shape (N,), True for the quizzes to keep.
    """
    # Resolved at call time: the command-line values are used unless a
    # caller overrides them explicitly.
    if proba_not_understands is None:
        proba_not_understands = args.proba_not_understands
    if proba_understands is None:
        proba_understands = args.proba_understands

    solution_logp = token_logprobas.sum(dim=-1)  # (N, M)
    ranked = solution_logp.sort(dim=-1).values  # least confident model first
    return (ranked[:, 0] < math.log(proba_not_understands)) & (
        ranked[:, 1] > math.log(proba_understands)
    )
# Filter every recorded (quizzes, token_logprobas) batch with
# compute_valid_quizzes and concatenate the survivors along the first
# dimension. Returns (quizzes, token_logprobas) for the kept quizzes.
# NOTE(review): the tail of this function (original 402-406) is not visible.
# The caller tests `validated_quizzes is not None`, so presumably an empty
# `recorded` yields (None, None); confirm.
392 def extract_valid_quizzes_and_logprobas(recorded):
393 validated_quizzes, validated_logprobas = [], []
394 for quizzes, token_logprobas in recorded:
# Boolean mask over the quizzes of this batch.
395 validated_indices = compute_valid_quizzes(token_logprobas)
396 validated_quizzes.append(quizzes[validated_indices])
397 validated_logprobas.append(token_logprobas[validated_indices])
# One (possibly empty) entry is appended per recorded batch, so this branch
# is taken whenever `recorded` is non-empty, even if no quiz passed; the
# result then has zero rows.
399 if len(validated_quizzes) > 0:
400 return torch.cat(validated_quizzes, dim=0), torch.cat(
401 validated_logprobas, dim=0
407 ######################################################################
# Generate candidate quizzes ("c_quizzes") until nb_for_train + nb_for_test
# of them pass compute_valid_quizzes. Then store them in the quiz machine,
# the first nb_for_train for training and the next nb_for_test for testing,
# and save a preview of the first 72.
#
# models: the model pool. Each round, one model drawn uniformly at random
#   generates quizzes, and all models score their solutions.
# quiz_machine: provides generation, filtering, scoring and storage.
#
# Reads args.generation_temperature, args.result_dir and the module-level
# n_epoch (used in the output file names).
410 def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
411 nb_to_create = nb_for_train + nb_for_test
# Every generated batch that survives the non-trivial filter, paired with the
# per-token log-probabilities of its solutions under all models.
413 recorded_quizzes_logprobas = []
# NOTE(review): nb_validated is initialised on a line not visible here
# (presumably 0). If so and nb_to_create <= 0, the loop never runs and the
# later uses of validated_quizzes raise NameError.
417 while nb_validated < nb_to_create:
# Indexing a list with a one-element integer tensor works via Tensor.__index__.
418 model_for_generation = models[torch.randint(len(models), (1,))]
420 c_quizzes = quiz_machine.generate_quizzes(
422 model_for_generation=model_for_generation,
423 temperature=args.generation_temperature,
# Keep only the quizzes accepted by quiz_machine.non_trivial.
426 c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
428 if c_quizzes.size(0) > 0:
429 token_logproba = quiz_machine.solution_token_logprobas(models, c_quizzes)
430 recorded_quizzes_logprobas.append((c_quizzes, token_logproba))
# NOTE(review): every round re-validates *all* recorded batches from scratch,
# so the cost of a round grows with the number of rounds; only the newest
# batch actually needs validating. The left-hand side of this assignment
# (original 431-434) is not visible.
435 ) = extract_valid_quizzes_and_logprobas(recorded_quizzes_logprobas)
437 if validated_quizzes is not None:
438 nb_validated = validated_quizzes.size(0)
441 f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create}"
444 # store the new c_quizzes which have been validated
# The in-place transformation (presumably reversing a random half of the
# quizzes) is applied before the split, so it affects both parts. Validated
# quizzes beyond nb_to_create are discarded.
446 quiz_machine.reverse_random_half_in_place(validated_quizzes)
447 quiz_machine.store_c_quizzes(validated_quizzes[:nb_for_train], for_train=True)
448 quiz_machine.store_c_quizzes(
449 validated_quizzes[nb_for_train:nb_to_create], for_train=False
452 ######################################################################
453 # save images with their logprobas
455 vq = validated_quizzes[:72]
456 vl = validated_logprobas[:72]
459 prefix = f"culture_c_quiz_{n_epoch:04d}"
460 filename = os.path.join(args.result_dir, prefix + "_logp.pth")
461 torch.save(vl, filename)
# NOTE(review): the commented-out text export below refers to `file_name` and
# `l`, neither of which is defined in this function; it is stale.
462 # with open(file_name, "w") as logp_file:
464 # s = " ".join([str(x.item()) for x in l])
465 # logp_file.write(s + "\n")
467 quiz_machine.save_quiz_illustrations(args.result_dir, prefix, vq)
470 ######################################################################
# Build args.nb_gpts models (the constructor header and the assignment of
# `model.id` are on lines not visible here). All use the same
# hyper-parameters from `args`. Each model gets its own freshly generated
# train and test world quizzes ("w_quizzes").
474 for k in range(args.nb_gpts):
475 log_string(f"creating model {k} and its w_quizzes")
477 vocabulary_size=vocabulary_size,
478 dim_model=args.dim_model,
479 dim_keys=args.dim_keys,
480 dim_hidden=args.dim_hidden,
481 nb_heads=args.nb_heads,
482 nb_blocks=args.nb_blocks,
484 dropout=args.dropout,
# Starting at 0.0 means the main loop only generates c_quizzes once every
# model has been trained above args.accuracy_to_make_c_quizzes (unless that
# threshold is <= 0).
487 model.main_test_accuracy = 0.0
490 model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples)
491 quiz_machine.reverse_random_half_in_place(model.train_w_quizzes)
492 model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples)
493 quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
497 ######################################################################
# Reload a previous run (presumably under `if args.resume:` and a loop over
# the models; those lines are not visible). A checkpoint is the tuple
# (state_dict, main_test_accuracy), matching what the training loop saves.
# A missing file is logged and skipped.
# NOTE(review): torch.load unpickles, so only load checkpoints you trust.
# Without map_location, a checkpoint saved from GPU tensors cannot be loaded
# on a machine without CUDA; consider map_location=main_device.
502 filename = f"gpt_{model.id:03d}.pth"
505 d = torch.load(os.path.join(args.result_dir, filename))
506 model.load_state_dict(d[0])
507 model.main_test_accuracy = d[1]
508 log_string(f"successfully loaded {filename}")
509 except FileNotFoundError:
510 log_string(f"cannot find {filename}")
# The c_quizzes accumulated so far, saved by the main loop.
514 filename = "c_quizzes.pth"
515 quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename))
516 log_string(f"successfully loaded {filename}")
517 except FileNotFoundError:
518 log_string(f"cannot find {filename}")
# NOTE(review): the `except` clause of this handler (original 519-521) is not
# visible; if it is a broad catch, the underlying error is hidden; consider
# logging it.
522 log_string(f"error when loading {filename}.")
######################################################################

# Size of the first model. All models are built from the same
# hyper-parameters, so presumably they all have this many parameters.
nb_parameters = sum(torch.numel(p) for p in models[0].parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters / 1e6)}M)")
530 ######################################################################
532 # Compute the entropy of the training tokens
# Unigram statistics of the training split. The initialisation of
# token_count and the end of the `.sum(` call (original 533-534, 537-538)
# are not visible; presumably the one-hot counts are reduced to one count
# per vocabulary entry. Confirm.
535 for input in quiz_machine.batches(models[0], split="train", desc="train-entropy"):
536 token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum(
539 token_probas = token_count / token_count.sum()
# torch.xlogy(p, p) is 0 where p == 0, so unused tokens contribute nothing.
# The entropy is in nats, and its exponential is the unigram perplexity.
# NOTE(review): train_set_perplexity is not used in the visible code.
540 entropy = -torch.xlogy(token_probas, token_probas).sum()
541 train_set_perplexity = math.exp(entropy)
543 ######################################################################
544 # A bit of paranoia never hurts
# Count how many test sequences also appear verbatim in the training split,
# and abort if that exceeds --max_percents_of_test_in_train percent. A
# negative value skips the check.
546 if args.max_percents_of_test_in_train >= 0:
# Turn batches of sequences into sets of hashable tuples, presumably
# yielding a set each time about `cs` sequences have been gathered (part of
# the body is not visible).
# NOTE(review): for an integer tensor row x, tuple(x.tolist()) gives the same
# tuple as the per-element .item() loop, much faster.
548 def subsets_as_tuples(batches, cs):
550 for batch in batches:
552 s.add(tuple([v.item() for v in x]))
# For every chunk of test sequences the whole training split is re-scanned,
# so the cost is (number of test chunks) x (train size). `in_train` is
# presumably reset per test chunk on a line not visible here.
558 nb_test, nb_in_train = 0, 0
559 for test_subset in subsets_as_tuples(
560 quiz_machine.batches(models[0], split="test", desc="test-check"), 25000
563 for train_subset in subsets_as_tuples(
564 quiz_machine.batches(models[0], split="train", desc="train-check"), 25000
566 in_train.update(test_subset.intersection(train_subset))
567 nb_in_train += len(in_train)
568 nb_test += len(test_subset)
# NOTE(review): divides by nb_test, which is 0 if the test split is empty;
# and the final check is an assert, which is stripped under `python -O`.
571 f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
575 nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
576 ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
578 ######################################################################
# Unless given on the command line, each c_quiz generation round produces 2%
# as many new quizzes as there are train / test samples.
580 if args.nb_new_c_quizzes_for_train is None:
581 args.nb_new_c_quizzes_for_train = args.nb_train_samples // 50
583 if args.nb_new_c_quizzes_for_test is None:
584 args.nb_new_c_quizzes_for_test = args.nb_test_samples // 50
# Argument of a log call whose opening line is not visible here.
587 f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"
590 ######################################################################
# Presumably under `if args.dirty_debug:` (original line 592, not visible):
# generate c_quizzes from the very first epoch, and only a few of them.
593 args.accuracy_to_make_c_quizzes = 0.0
595 args.nb_new_c_quizzes_for_train = 100
596 args.nb_new_c_quizzes_for_test = 10
599 ######################################################################
# Main loop. Each epoch: possibly create new c_quizzes, then train the
# len(gpus) weakest models in parallel threads (one per device in `gpus`),
# save them, and renew their training quizzes.
# NOTE(review): `threading` is used below but no `import threading` is
# visible in this extract; confirm it is imported.
601 for n_epoch in range(args.nb_epochs):
602 log_string(f"--- epoch {n_epoch} ----------------------------------------")
604 cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
605 log_string(f"current_test_accuracies {cta}")
607 ##################################################
608 # If all the models are good enough, generate new quizzes and
609 # re-compute the test errors
# Once every model reaches args.accuracy_to_make_c_quizzes, create new
# c_quizzes and save all of them to c_quizzes.pth. Then reset the stored
# accuracies (presumably of every model; the loop header is not visible),
# which forces retraining rather than re-evaluating here.
611 if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
615 nb_for_train=args.nb_new_c_quizzes_for_train,
616 nb_for_test=args.nb_new_c_quizzes_for_test,
619 filename = "c_quizzes.pth"
620 quiz_machine.save_c_quizzes(os.path.join(args.result_dir, filename))
621 log_string(f"wrote {filename}")
623 # Force one epoch of training
625 model.main_test_accuracy = 0.0
627 ##################################################
628 # Select, improve, and eval the worst model
# Only len(gpus) models, those with the lowest accuracy, are trained each
# epoch, one per device in `gpus`; if `gpus` is empty, no model is trained.
630 ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
632 weakest_models = ranked_models[: len(gpus)]
# The lines that start and join these threads (original 641-649) are not
# visible; presumably all threads are joined before the models are saved.
# Confirm, otherwise the checkpoints below could be taken mid-training.
636 for gpu, model in zip(gpus, weakest_models):
637 log_string(f"training model {model.id}")
639 t = threading.Thread(
640 target=one_epoch, daemon=True, args=(model, quiz_machine, gpu)
650 # Save the models to disk
# Checkpoint format: (state_dict, main_test_accuracy), as read back when
# resuming.
652 for model in weakest_models:
653 filename = f"gpt_{model.id:03d}.pth"
655 (model.state_dict(), model.main_test_accuracy),
656 os.path.join(args.result_dir, filename),
658 log_string(f"wrote {filename}")
660 # Renew the training samples
# Only the models trained this epoch get renewed world quizzes.
662 for model in weakest_models:
663 quiz_machine.renew_w_quizzes(model, args.nb_train_samples)
666 ######################################################################