#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, sys, argparse, time, tqdm, os, datetime, warnings

import torch, torchvision
from torch import nn
from torch.nn import functional as F

import ffutils
import mygpt, tasks, problems

######################################################################

if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.backends.cuda.matmul.allow_tf32 = True
else:
    device = torch.device("cpu")

######################################################################

parser = argparse.ArgumentParser(
    description="An implementation of GPT with cache.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument("--task", type=str, default="world", help="world")

parser.add_argument("--log_filename", type=str, default="train.log", help=" ")

parser.add_argument("--result_dir", type=str, default=None)

parser.add_argument("--seed", type=int, default=0)

parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)

########################################

parser.add_argument("--nb_epochs", type=int, default=10000)

parser.add_argument("--batch_size", type=int, default=None)

parser.add_argument("--physical_batch_size", type=int, default=None)

parser.add_argument("--nb_train_samples", type=int, default=None)

parser.add_argument("--nb_test_samples", type=int, default=None)

parser.add_argument("--learning_rate", type=float, default=1e-4)

########################################

parser.add_argument("--model", type=str, default=None)

parser.add_argument("--dim_model", type=int, default=None)

parser.add_argument("--dim_keys", type=int, default=None)

parser.add_argument("--dim_hidden", type=int, default=None)

parser.add_argument("--nb_heads", type=int, default=None)

parser.add_argument("--nb_blocks", type=int, default=None)

parser.add_argument("--dropout", type=float, default=0.1)

########################################

parser.add_argument("--deterministic_synthesis", action="store_true", default=False)

parser.add_argument("--nb_gpts", type=int, default=5)

parser.add_argument("--check", action="store_true", default=False)

######################################################################

args = parser.parse_args()

if args.result_dir is None:
    args.result_dir = f"results_{args.task}"

######################################################################

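# Per-task default hyper-parameters, applied only to the values still left
# at None on the command line.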
default_task_args = {
    "world": {
        "model": "37M",
        "batch_size": 100,
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
}

if args.task in default_task_args:
    for k, v in default_task_args[args.task].items():
        if getattr(args, k) is None:
            setattr(args, k, v)

######################################################################

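# Named model configurations, keyed by approximate parameter count. As with
# the task defaults, they only fill in dimensions left at None on the
# command line.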
default_model_args = {
    "17K": {
        "dim_model": 32,
        "dim_keys": 32,
        "dim_hidden": 32,
        "nb_heads": 2,
        "nb_blocks": 2,
    },
    "4M": {
        "dim_model": 256,
        "dim_keys": 32,
        "dim_hidden": 1024,
        "nb_heads": 4,
        "nb_blocks": 6,
    },
    "37M": {
        "dim_model": 512,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_blocks": 12,
    },
    "122M": {
        "dim_model": 768,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_blocks": 24,
    },
    "352M": {
        "dim_model": 1024,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_blocks": 48,
    },
}

if args.model in default_model_args:
    for k, v in default_model_args[args.model].items():
        if getattr(args, k) is None:
            setattr(args, k, v)
else:
    raise ValueError(f"Unknown model {args.model}")

######################################################################

try:
    os.mkdir(args.result_dir)
except FileExistsError:
    print(f"result directory {args.result_dir} already exists")
    exit(1)

log_file = open(os.path.join(args.result_dir, args.log_filename), "a")

if args.seed >= 0:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

######################################################################


def log_string(s):
    t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())

    if log_file is not None:
        log_file.write(t + s + "\n")
        log_file.flush()

    print(t + s)
    sys.stdout.flush()


log_string(f"argv {' '.join(sys.argv)}")

for n in vars(args):
    log_string(f"args.{n} {getattr(args, n)}")


######################################################################

if args.check:
    args.nb_train_samples = 500
    args.nb_test_samples = 100

if args.physical_batch_size is None:
    args.physical_batch_size = args.batch_size
else:
    assert args.batch_size % args.physical_batch_size == 0

assert args.nb_train_samples % args.batch_size == 0
assert args.nb_test_samples % args.batch_size == 0

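# Instantiate the task. Only "world" has defaults defined above; the other
# branches rely on task-specific arguments (e.g. --maze_height,
# --filetask_train_file) that this parser does not declare.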
if args.task == "file":
    assert (
        args.filetask_train_file is not None and args.filetask_test_file is not None
    ), "You have to specify the task train and test files"
    task = tasks.TaskFromFile(
        args.filetask_train_file,
        args.filetask_test_file,
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        shuffle=True,
        device=device,
    )
    args.max_percents_of_test_in_train = 0

elif args.task == "byheart":
    task = tasks.SandBox(
        problem=problems.ProblemByHeart(separation=args.byheart_separation),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device,
    )
    args.max_percents_of_test_in_train = -1

elif args.task == "world":
    task = tasks.World(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        result_dir=args.result_dir,
        logger=log_string,
        device=device,
    )
    args.max_percents_of_test_in_train = -1

elif args.task == "learnop":
    task = tasks.SandBox(
        problem=problems.ProblemLearnOperator(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device,
    )

elif args.task == "guessop":
    task = tasks.SandBox(
        problem=problems.ProblemGuessOperator(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device,
    )

elif args.task == "twotargets":
    task = tasks.SandBox(
        problem=problems.ProblemTwoTargets(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device,
    )

elif args.task == "memory":
    task = tasks.SandBox(
        problem=problems.ProblemMemory(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device,
    )

elif args.task == "mixing":
    task = tasks.SandBox(
        problem=problems.ProblemMixing(
            hard=args.mixing_hard, random_start=not args.mixing_deterministic_start
        ),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device,
    )

elif args.task == "addition":
    task = tasks.SandBox(
        problem=problems.ProblemAddition(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device,
    )

elif args.task == "picoclvr":
    task = tasks.PicoCLVR(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        height=args.picoclvr_height,
        width=args.picoclvr_width,
        nb_colors=args.picoclvr_nb_colors,
        logger=log_string,
        device=device,
        pruner_train=picoclvr_pruner_train,
        pruner_eval=picoclvr_pruner_eval,
    )

elif args.task == "mnist":
    task = tasks.MNIST(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        device=device,
    )

elif args.task == "maze":
    task = tasks.Maze(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        height=args.maze_height,
        width=args.maze_width,
        nb_walls=args.maze_nb_walls,
        device="cpu",
    )

elif args.task == "snake":
    task = tasks.Snake(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        height=args.snake_height,
        width=args.snake_width,
        nb_colors=args.snake_nb_colors,
        length=args.snake_length,
        prompt_length=args.snake_length // 2,
        device=device,
    )

elif args.task == "stack":
    task = tasks.Stack(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        nb_steps=args.stack_nb_steps,
        nb_stacks=args.stack_nb_stacks,
        nb_digits=args.stack_nb_digits,
        fraction_values_for_train=args.stack_fraction_values_for_train,
        device=device,
    )

elif args.task == "expr":
    task = tasks.Expr(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        nb_variables=args.expr_nb_variables,
        sequence_length=args.expr_sequence_length,
        operand_max=args.expr_operand_max,
        result_max=args.expr_result_max,
        batch_size=args.physical_batch_size,
        device=device,
    )

elif args.task == "rpl":
    task = tasks.RPL(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        nb_starting_values=args.rpl_nb_starting_values,
        max_input=args.rpl_max_input,
        prog_len=args.rpl_prog_len,
        nb_runs=args.rpl_nb_runs,
        no_prog=args.rpl_no_prog,
        logger=log_string,
        device=device,
    )

elif args.task == "grid":
    task = tasks.Grid(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        size=args.grid_size,
        fraction_play=args.grid_fraction_play,
        logger=log_string,
        device=device,
    )

elif args.task == "qmlp":
    task = tasks.QMLP(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        result_dir=args.result_dir,
        logger=log_string,
        device=device,
    )

elif args.task == "greed":
    task = tasks.Greed(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        height=args.greed_height,
        width=args.greed_width,
        T=args.greed_T,
        nb_walls=args.greed_nb_walls,
        nb_coins=args.greed_nb_coins,
        logger=log_string,
        device=device,
    )

else:
    raise ValueError(f"Unknown task {args.task}")

######################################################################

log_string(f"device {device}")

vocabulary_size = task.vocabulary_size()

log_string(f"vocabulary_size {vocabulary_size}")

######################################################################

# Compute the entropy of the training tokens
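# token_count is accumulated over all training batches; token_probas is the
# empirical unigram distribution of the tokens and train_set_perplexity =
# exp(entropy) is the perplexity a unigram model would achieve on them.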

token_count = 0
for input in task.batches(split="train", desc="train-entropy"):
    token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)

######################################################################
# A bit of paranoia never hurts
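# Measure how many test samples also appear verbatim in the train set. The
# sequences are compared as tuples of token values, in chunks of 25000 to
# bound memory, and training aborts if the overlap exceeds
# --max_percents_of_test_in_train (tasks that set it to -1 skip the check).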

if args.max_percents_of_test_in_train >= 0:

    def subsets_as_tuples(batches, cs):
        s = set()
        for batch in batches:
            for x in batch:
                s.add(tuple([v.item() for v in x]))
                if len(s) == cs:
                    yield s
                    s = set()
        yield s

    nb_test, nb_in_train = 0, 0
    for test_subset in subsets_as_tuples(
        task.batches(split="test", desc="test-check"), 25000
    ):
        in_train = set()
        for train_subset in subsets_as_tuples(
            task.batches(split="train", desc="train-check"), 25000
        ):
            in_train.update(test_subset.intersection(train_subset))
        nb_in_train += len(in_train)
        nb_test += len(test_subset)

    log_string(
        f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
    )

    assert (
        nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
    ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"

##############################


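# One pass over the training set with cross-entropy training of a single
# model. A fresh Adam optimizer is created at each call, gradients are
# accumulated over physical batches, and an optimizer step is taken once per
# logical batch of --batch_size samples. Uses the global n_epoch (set in the
# main loop below) for logging.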
def one_epoch(model, task):
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split="train"):
        input = input.to(device)

        if nb_train_samples % args.batch_size == 0:
            optimizer.zero_grad()

        output = model(mygpt.BracketedSequence(input)).x
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)

        nb_train_samples += input.size(0)

        loss.backward()

        if nb_train_samples % args.batch_size == 0:
            optimizer.step()

    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))

    log_string(f"train_perplexity {n_epoch} {train_perplexity}")


######################################################################


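# Evaluation: computes the test cross-entropy of a model and asks the task
# to produce results for the current epoch. The returned main_test_accuracy
# is stored on the model and drives the scheduling in the main loop.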
def run_tests(model, task, deterministic_synthesis):
    with torch.autograd.no_grad():
        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0
        nb_samples_accumulated = 0

        for input in task.batches(split="test"):
            input = input.to(device)

            bs = model(mygpt.BracketedSequence(input))
            output = bs.x

            loss = F.cross_entropy(output.transpose(1, 2), input)

            acc_test_loss += loss.item() * input.size(0)

            nb_test_samples += input.size(0)

        main_test_accuracy = task.produce_results(
            n_epoch=n_epoch,
            model=model,
            result_dir=args.result_dir,
            logger=log_string,
            deterministic_synthesis=deterministic_synthesis,
        )

        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))

        log_string(f"test_perplexity {n_epoch} {test_perplexity}")

    model.main_test_accuracy = main_test_accuracy


######################################################################


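# Quiz creation: the given model generates batches of candidate quizzes and
# the task evaluates how many of the other models answer each one correctly.
# Candidates whose nb_correct equals len(other_models) - 1 are kept, and
# generation loops until nb_for_train + nb_for_test quizzes have been
# collected; these are added to the task's train and test sets and the first
# 96 are saved as an image.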
def create_quizzes(
    model,
    other_models,
    task,
    nb_for_train=1000,
    nb_for_test=100,
):
    kept = []

    while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
        new_quizzes, nb_correct = task.create_new_quizzes(
            n_epoch=n_epoch,
            result_dir=args.result_dir,
            logger=log_string,
            nb=4 * (nb_for_train + nb_for_test),
            model=model,
            other_models=other_models,
        )

        to_keep = new_quizzes[nb_correct == len(other_models) - 1]
        log_string(f"keep {to_keep.size(0)} quizzes")
        kept.append(to_keep)

    new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]

    task.store_new_quizzes(new_quizzes[:nb_for_train], for_train=True)
    task.store_new_quizzes(new_quizzes[nb_for_train:], for_train=False)

    task.save_image(
        new_quizzes[:96],
        args.result_dir,
        f"world_quiz_{n_epoch:04d}_{model.id:02d}.png",
        log_string,
    )


######################################################################

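# Instantiate the population of GPTs. They share the same architecture; each
# one keeps track of its own id and of the last main_test_accuracy measured
# by run_tests.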
models = []

for k in range(args.nb_gpts):
    model = mygpt.MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=args.dim_model,
        dim_keys=args.dim_keys,
        dim_hidden=args.dim_hidden,
        nb_heads=args.nb_heads,
        nb_blocks=args.nb_blocks,
        causal=True,
        dropout=args.dropout,
    ).to(device)

    model.main_test_accuracy = 0.0
    model.id = k

    models.append(model)


nb_parameters = sum(p.numel() for p in models[0].parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")

######################################################################

accuracy_to_make_quizzes = 0.975
nb_new_quizzes_for_train = 1000
nb_new_quizzes_for_test = 100

if args.check:
    accuracy_to_make_quizzes = 0.0
    nb_new_quizzes_for_train = 10
    nb_new_quizzes_for_test = 10

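# Main loop: at every epoch the model with the lowest test accuracy is
# trained for one epoch and re-tested; once it reaches
# accuracy_to_make_quizzes it generates new quizzes, validated by the other
# models, that are added to the training data.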
for n_epoch in range(args.nb_epochs):
    # select the model with lowest accuracy
    models.sort(key=lambda model: model.main_test_accuracy)
    model = models[0]

    log_string(
        f"training model {model.id} main_test_accuracy {model.main_test_accuracy}"
    )

    # improve it
    one_epoch(model, task)

    log_string(
        f"train_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}"
    )

    # test it
    run_tests(model, task, deterministic_synthesis=False)

    if model.main_test_accuracy >= accuracy_to_make_quizzes:
        other_models = models.copy()
        other_models.remove(model)

        create_quizzes(
            model,
            other_models,
            task,
            nb_for_train=nb_new_quizzes_for_train,
            nb_for_test=nb_new_quizzes_for_test,
        )


######################################################################