[culture.git] / main.py
#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, sys, argparse, time, tqdm, os, datetime, warnings

import torch, torchvision
from torch import nn
from torch.nn import functional as F

import ffutils
import mygpt, tasks, problems

######################################################################

if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.backends.cuda.matmul.allow_tf32 = True
else:
    device = torch.device("cpu")

######################################################################

parser = argparse.ArgumentParser(
    description="An implementation of GPT with cache.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument("--task", type=str, default="world", help="world")

parser.add_argument("--log_filename", type=str, default="train.log", help=" ")

parser.add_argument("--result_dir", type=str, default=None)

parser.add_argument("--seed", type=int, default=0)

parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)

########################################

parser.add_argument("--nb_epochs", type=int, default=10000)

parser.add_argument("--batch_size", type=int, default=None)

parser.add_argument("--physical_batch_size", type=int, default=None)

parser.add_argument("--nb_train_samples", type=int, default=None)

parser.add_argument("--nb_test_samples", type=int, default=None)

parser.add_argument("--learning_rate", type=float, default=1e-4)

########################################

parser.add_argument("--model", type=str, default=None)

parser.add_argument("--dim_model", type=int, default=None)

parser.add_argument("--dim_keys", type=int, default=None)

parser.add_argument("--dim_hidden", type=int, default=None)

parser.add_argument("--nb_heads", type=int, default=None)

parser.add_argument("--nb_blocks", type=int, default=None)

parser.add_argument("--dropout", type=float, default=0.1)

########################################

parser.add_argument("--deterministic_synthesis", action="store_true", default=False)

######################################################################

args = parser.parse_args()

if args.result_dir is None:
    args.result_dir = f"results_{args.task}"

######################################################################

default_task_args = {
    "world": {
        "model": "37M",
        "batch_size": 100,
        "nb_train_samples": 250000,
        "nb_test_samples": 10000,
    },
}

if args.task in default_task_args:
    for k, v in default_task_args[args.task].items():
        if getattr(args, k) is None:
            setattr(args, k, v)

######################################################################

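# Predefined model geometries, keyed by their approximate number of parameters.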
default_model_args = {
    "17K": {
        "dim_model": 32,
        "dim_keys": 32,
        "dim_hidden": 32,
        "nb_heads": 2,
        "nb_blocks": 2,
    },
    "4M": {
        "dim_model": 256,
        "dim_keys": 32,
        "dim_hidden": 1024,
        "nb_heads": 4,
        "nb_blocks": 6,
    },
    "37M": {
        "dim_model": 512,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_blocks": 12,
    },
    "122M": {
        "dim_model": 768,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_blocks": 24,
    },
    "352M": {
        "dim_model": 1024,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_blocks": 48,
    },
}

if args.model in default_model_args:
    for k, v in default_model_args[args.model].items():
        if getattr(args, k) is None:
            setattr(args, k, v)
else:
    raise ValueError(f"Unknown model {args.model}")

######################################################################

try:
    os.mkdir(args.result_dir)
except FileExistsError:
    print(f"result directory {args.result_dir} already exists")
    exit(1)

log_file = open(os.path.join(args.result_dir, args.log_filename), "a")

if args.seed >= 0:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

######################################################################


def log_string(s):
    t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())

    if log_file is not None:
        log_file.write(t + s + "\n")
        log_file.flush()

    print(t + s)
    sys.stdout.flush()


log_string(f"argv {' '.join(sys.argv)}")

for n in vars(args):
    log_string(f"args.{n} {getattr(args, n)}")


######################################################################

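# --batch_size is the logical batch over which the optimizer steps; it may be
# split into smaller physical batches whose gradients are accumulated (see
# one_epoch below).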
if args.physical_batch_size is None:
    args.physical_batch_size = args.batch_size
else:
    assert args.batch_size % args.physical_batch_size == 0

assert args.nb_train_samples % args.batch_size == 0
assert args.nb_test_samples % args.batch_size == 0

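# Only the "world" task has all of its arguments declared in the parser above;
# the other branches below reference task-specific arguments (e.g.
# args.picoclvr_height) that are not defined here and will fail if selected.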
if args.task == "file":
    assert (
        args.filetask_train_file is not None and args.filetask_test_file is not None
    ), "You have to specify the task train and test files"
    task = tasks.TaskFromFile(
        args.filetask_train_file,
        args.filetask_test_file,
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        shuffle=True,
        device=device,
    )
    args.max_percents_of_test_in_train = 0

elif args.task == "byheart":
    task = tasks.SandBox(
        problem=problems.ProblemByHeart(separation=args.byheart_separation),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device,
    )
    args.max_percents_of_test_in_train = -1

elif args.task == "world":
    task = tasks.World(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        result_dir=args.result_dir,
        logger=log_string,
        device=device,
    )
    args.max_percents_of_test_in_train = -1

elif args.task == "learnop":
    task = tasks.SandBox(
        problem=problems.ProblemLearnOperator(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device,
    )


elif args.task == "guessop":
    task = tasks.SandBox(
        problem=problems.ProblemGuessOperator(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device,
    )


elif args.task == "twotargets":
    task = tasks.SandBox(
        problem=problems.ProblemTwoTargets(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device,
    )

elif args.task == "memory":
    task = tasks.SandBox(
        problem=problems.ProblemMemory(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device,
    )

elif args.task == "mixing":
    task = tasks.SandBox(
        problem=problems.ProblemMixing(
            hard=args.mixing_hard, random_start=not args.mixing_deterministic_start
        ),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device,
    )

elif args.task == "addition":
    task = tasks.SandBox(
        problem=problems.ProblemAddition(),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        device=device,
    )

elif args.task == "picoclvr":
    task = tasks.PicoCLVR(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        height=args.picoclvr_height,
        width=args.picoclvr_width,
        nb_colors=args.picoclvr_nb_colors,
        logger=log_string,
        device=device,
        pruner_train=picoclvr_pruner_train,
        pruner_eval=picoclvr_pruner_eval,
    )

elif args.task == "mnist":
    task = tasks.MNIST(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        device=device,
    )

elif args.task == "maze":
    task = tasks.Maze(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        height=args.maze_height,
        width=args.maze_width,
        nb_walls=args.maze_nb_walls,
        device="cpu",
    )

elif args.task == "snake":
    task = tasks.Snake(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        height=args.snake_height,
        width=args.snake_width,
        nb_colors=args.snake_nb_colors,
        length=args.snake_length,
        prompt_length=args.snake_length // 2,
        device=device,
    )

elif args.task == "stack":
    task = tasks.Stack(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        logger=log_string,
        nb_steps=args.stack_nb_steps,
        nb_stacks=args.stack_nb_stacks,
        nb_digits=args.stack_nb_digits,
        fraction_values_for_train=args.stack_fraction_values_for_train,
        device=device,
    )

elif args.task == "expr":
    task = tasks.Expr(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        nb_variables=args.expr_nb_variables,
        sequence_length=args.expr_sequence_length,
        operand_max=args.expr_operand_max,
        result_max=args.expr_result_max,
        batch_size=args.physical_batch_size,
        device=device,
    )

elif args.task == "rpl":
    task = tasks.RPL(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        nb_starting_values=args.rpl_nb_starting_values,
        max_input=args.rpl_max_input,
        prog_len=args.rpl_prog_len,
        nb_runs=args.rpl_nb_runs,
        no_prog=args.rpl_no_prog,
        logger=log_string,
        device=device,
    )

elif args.task == "grid":
    task = tasks.Grid(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        size=args.grid_size,
        fraction_play=args.grid_fraction_play,
        logger=log_string,
        device=device,
    )

elif args.task == "qmlp":
    task = tasks.QMLP(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        result_dir=args.result_dir,
        logger=log_string,
        device=device,
    )

elif args.task == "greed":
    task = tasks.Greed(
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.physical_batch_size,
        height=args.greed_height,
        width=args.greed_width,
        T=args.greed_T,
        nb_walls=args.greed_nb_walls,
        nb_coins=args.greed_nb_coins,
        logger=log_string,
        device=device,
    )

else:
    raise ValueError(f"Unknown task {args.task}")

######################################################################

log_string(f"device {device}")

vocabulary_size = task.vocabulary_size()

log_string(f"vocabulary_size {vocabulary_size}")

######################################################################

# Compute the entropy of the training tokens

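# exp(entropy) is the perplexity of a unigram model of the training tokens, a
# simple reference point for the model perplexities logged during training.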
token_count = 0
for input in task.batches(split="train", desc="train-entropy"):
    token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)

######################################################################
# A bit of paranoia never hurts

if args.max_percents_of_test_in_train >= 0:

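    # Group samples into sets of at most cs tuples, so that the train/test
    # intersection can be computed chunk by chunk instead of all at once.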
    def subsets_as_tuples(batches, cs):
        s = set()
        for batch in batches:
            for x in batch:
                s.add(tuple([v.item() for v in x]))
                if len(s) == cs:
                    yield s
                    s = set()
        yield s

    nb_test, nb_in_train = 0, 0
    for test_subset in subsets_as_tuples(
        task.batches(split="test", desc="test-check"), 25000
    ):
        in_train = set()
        for train_subset in subsets_as_tuples(
            task.batches(split="train", desc="train-check"), 25000
        ):
            in_train.update(test_subset.intersection(train_subset))
        nb_in_train += len(in_train)
        nb_test += len(test_subset)

    log_string(
        f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
    )

    assert (
        nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
    ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"

##############################


def one_epoch(model, task):
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split="train"):
        input = input.to(device)

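        # The optimizer is zeroed / stepped once per logical batch; gradients
        # accumulate over the physical batches in between.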
        if nb_train_samples % args.batch_size == 0:
            optimizer.zero_grad()

        output = model(mygpt.BracketedSequence(input)).x
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)

        nb_train_samples += input.size(0)

        loss.backward()

        if nb_train_samples % args.batch_size == 0:
            optimizer.step()

    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))

    log_string(f"train_perplexity {n_epoch} {train_perplexity}")


######################################################################


def run_tests(model, task, deterministic_synthesis):
    with torch.autograd.no_grad():
        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0
        nb_samples_accumulated = 0

        for input in task.batches(split="test"):
            input = input.to(device)

            bs = model(mygpt.BracketedSequence(input))
            output = bs.x

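            # F.cross_entropy expects the class dimension second, hence the
            # transpose of the (N, T, C) logits to (N, C, T).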
            loss = F.cross_entropy(output.transpose(1, 2), input)

            acc_test_loss += loss.item() * input.size(0)

            nb_test_samples += input.size(0)

        main_test_accuracy = task.produce_results(
            n_epoch=n_epoch,
            model=model,
            result_dir=args.result_dir,
            logger=log_string,
            deterministic_synthesis=deterministic_synthesis,
        )

        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))

        log_string(f"test_perplexity {n_epoch} {test_perplexity}")

    model.main_test_accuracy = main_test_accuracy


######################################################################


def create_quizzes(
    model,
    other_models,
    task,
    nb_for_train=1000,
    nb_for_test=100,
):
    kept = []

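    # Generate quizzes until enough are kept; a quiz is kept when its
    # nb_correct count equals len(other_models) - 1, presumably the number of
    # other models that solve it.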
    while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
        new_quizzes, nb_correct = task.create_new_quizzes(
            n_epoch=n_epoch,
            result_dir=args.result_dir,
            logger=log_string,
            nb=4 * (nb_for_train + nb_for_test),
            model=model,
            other_models=other_models,
        )

        to_keep = new_quizzes[nb_correct == len(other_models) - 1]
        log_string(f"keep {to_keep.size(0)} quizzes")
        kept.append(to_keep)

    new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]

    task.store_new_quizzes(new_quizzes[:nb_for_train], for_train=True)
    task.store_new_quizzes(new_quizzes[nb_for_train:], for_train=False)

    task.save_image(
        new_quizzes[:96],
        args.result_dir,
        f"world_new_{n_epoch:04d}_{model.id:02d}.png",
        log_string,
    )


######################################################################

models = []

for k in range(5):
    model = mygpt.MyGPT(
        vocabulary_size=vocabulary_size,
        dim_model=args.dim_model,
        dim_keys=args.dim_keys,
        dim_hidden=args.dim_hidden,
        nb_heads=args.nb_heads,
        nb_blocks=args.nb_blocks,
        causal=True,
        dropout=args.dropout,
    ).to(device)

    model.main_test_accuracy = 0.0
    model.id = k

    models.append(model)


nb_parameters = sum(p.numel() for p in models[0].parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")

######################################################################

accuracy_to_make_quizzes = 0.975

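# Main loop: at every epoch the weakest of the five models (lowest test
# accuracy) is trained for one more epoch, tested, and, once accurate enough,
# asked to create new quizzes for the shared train and test sets.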
for n_epoch in range(args.nb_epochs):
    # select the model with the lowest accuracy
    models.sort(key=lambda model: model.main_test_accuracy)
    model = models[0]

    log_string(
        f"training model {model.id} main_test_accuracy {model.main_test_accuracy}"
    )

    # improve it
    one_epoch(model, task)

    log_string(
        f"train_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}"
    )

    # test it
    run_tests(model, task, deterministic_synthesis=False)

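    # The other models act as validators for the quizzes created by the newly
    # trained one.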
    if model.main_test_accuracy >= accuracy_to_make_quizzes:
        other_models = models.copy()
        other_models.remove(model)

        create_quizzes(
            model,
            other_models,
            task,
            nb_for_train=1000,
            nb_for_test=100,
        )


######################################################################