Added default configurations and reformatted with black.
[mygpt.git] / main.py
#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, sys, argparse, time, tqdm, itertools

import torch, torchtext, torchvision
from torch import nn
from torch.nn import functional as F

import mygpt

######################################################################

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

######################################################################
parser = argparse.ArgumentParser(description="My own GPT.")

parser.add_argument("--log_filename", type=str, default="train.log")

parser.add_argument("--seed", type=int, default=0)

parser.add_argument("--nb_epochs", type=int, default=None)

parser.add_argument("--batch_size", type=int, default=25)

parser.add_argument("--data", type=str, default="wiki103")

parser.add_argument("--data_size", type=int, default=None)

parser.add_argument("--optim", type=str, default="adam")

parser.add_argument("--learning_rate", type=float, default=1e-3)

parser.add_argument("--learning_rate_end", type=float, default=1e-6)

parser.add_argument("--dim_model", type=int, default=None)

parser.add_argument("--dim_keys", type=int, default=None)

parser.add_argument("--dim_hidden", type=int, default=None)

parser.add_argument("--nb_heads", type=int, default=None)

parser.add_argument("--nb_blocks", type=int, default=None)

parser.add_argument("--dropout", type=float, default=0.1)

parser.add_argument("--deterministic_synthesis", action="store_true", default=False)

parser.add_argument("--no_checkpoint", action="store_true", default=False)

parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")

##############################
# picoclvr options

parser.add_argument("--picoclvr_nb_colors", type=int, default=5)

parser.add_argument("--picoclvr_height", type=int, default=12)

parser.add_argument("--picoclvr_width", type=int, default=16)

######################################################################

args = parser.parse_args()

log_file = open(args.log_filename, "w")

if args.seed >= 0:
    torch.manual_seed(args.seed)

######################################################################


def log_string(s):
    t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())

    if log_file is not None:
        log_file.write(t + s + "\n")
        log_file.flush()

    print(t + s)
    sys.stdout.flush()


for n in vars(args):
    log_string(f"args.{n} {getattr(args, n)}")

######################################################################

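# Per-dataset default hyper-parameters; they only fill in the command-line
# arguments that were left to None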
default_args = {
    "mnist": {
        "nb_epochs": 10,
        "dim_model": 64,
        "dim_keys": 64,
        "dim_hidden": 128,
        "nb_heads": 4,
        "nb_blocks": 6,
    },
    "mnist-debug": {
        "nb_epochs": 2,
        "data_size": 10000,
        "dim_model": 8,
        "dim_keys": 8,
        "dim_hidden": 8,
        "nb_heads": 2,
        "nb_blocks": 4,
    },
    "wiki103": {
        "nb_epochs": 25,
        "dim_model": 512,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_blocks": 12,
    },
    "picoclvr": {
        "nb_epochs": 25,
        "dim_model": 512,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_blocks": 12,
    },
}

if args.data in default_args:
    for k, v in default_args[args.data].items():
        if getattr(args, k) is None:
            setattr(args, k, v)

######################################################################


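# Sample sequences auto-regressively: positions before `first` come from the
# primer, then each position s is filled in turn, either with the argmax of
# the logits (--deterministic_synthesis) or with a sample from the
# corresponding categorical distribution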
def autoregression(
    model,
    batch_size,
    nb_samples,
    nb_tokens_to_generate,
    primer=None,
    device=torch.device("cpu"),
):
    results = torch.zeros(
        nb_samples, nb_tokens_to_generate, dtype=torch.int64, device=device
    )

    if primer is None:
        first = 0
    else:
        first = primer.size(1)
        results = torch.cat((primer, results), 1)

    for input in results.split(batch_size):
        for s in range(first, input.size(1)):
            output = model(input)
            logits = output[:, s]
            if args.deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            input[:, s] = t_next

    return results


######################################################################


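# Abstract interface of a task: provides the train / test batches, the
# vocabulary size, and a procedure that generates and saves sample results
# at the end of each epoch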
class Task:
    def batches(self, split="train"):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(self, n_epoch, model):
        pass


######################################################################

import picoclvr


class TaskPicoCLVR(Task):

    # Make a tensor from a list of strings
    def tensorize(self, descr):
        token_descr = [s.strip().split(" ") for s in descr]
        l = max([len(s) for s in token_descr])
        padded_token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
        id_descr = [[self.token2id[u] for u in s] for s in padded_token_descr]
        return torch.tensor(id_descr, device=self.device)

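    # Remove the leading and trailing columns that contain only <nul>
    # padding across the whole batch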
    def trim(self, x, token="<nul>"):
        n = self.token2id[token]
        i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
        a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
        return x[:, a:b]

    def __init__(
        self, batch_size, height, width, nb_colors=5, device=torch.device("cpu")
    ):
        def generate_descr(nb):
            return picoclvr.generate(
                nb, height=self.height, width=self.width, nb_colors=nb_colors
            )

        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        nb = args.data_size if args.data_size is not None else 250000

        log_string(f"generating {nb} samples (can take some time)")
        self.train_descr = generate_descr((nb * 4) // 5)
        self.test_descr = generate_descr((nb * 1) // 5)

        # Build the tokenizer
        tokens = {"<nul>"}
        for d in [self.train_descr, self.test_descr]:
            for s in d:
                for t in s.strip().split(" "):
                    tokens.add(t)
        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)

    def batches(self, split="train"):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        for batch in tqdm.tqdm(input.split(self.batch_size), desc=f"epoch-{split}"):
            yield self.trim(batch)

    def vocabulary_size(self):
        return len(self.token2id)

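    # Generate nb_per_primer sequences per primer, count the requested and
    # missing properties with picoclvr.nb_properties, and optionally render
    # the resulting descriptions as images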
    def test_model(
        self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False
    ):
        nb_tokens_to_generate = self.height * self.width + 3
        result_descr = []

        for primer_descr in primers_descr:

            results = autoregression(
                model,
                self.batch_size,
                nb_samples=nb_per_primer,
                nb_tokens_to_generate=nb_tokens_to_generate,
                primer=self.tensorize([primer_descr]).expand(nb_per_primer, -1),
                device=self.device,
            )

            l = [" ".join([self.id2token[t.item()] for t in r]) for r in results]
            result_descr += l

        np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)

        nb_requested_properties, _, nb_missing_properties = zip(*np)

        log_string(
            f"nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}"
        )

        np = torch.tensor(np)
        count = torch.empty(np[:, 0].max() + 1, np[:, 2].max() + 1, dtype=torch.int64)
        for i in range(count.size(0)):
            for j in range(count.size(1)):
                count[i, j] = ((np[:, 0] == i).long() * (np[:, 2] == j).long()).sum()

        if generate_images:
            img = [
                picoclvr.descr2img(d, height=self.height, width=self.width)
                for d in result_descr
            ]

            img = torch.cat(img, 0)
            image_name = f"result_picoclvr_{n_epoch:04d}.png"
            torchvision.utils.save_image(
                img / 255.0, image_name, nrow=nb_per_primer, pad_value=0.8
            )
            log_string(f"wrote {image_name}")

        return count

    def produce_results(self, n_epoch, model):
        primers_descr = [
            "red above green <sep> green top <sep> blue right of red <img>",
            "there is red <sep> there is yellow <sep> there is blue <img>",
            "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>",
            "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>",
        ]

        self.test_model(
            n_epoch, model, primers_descr, nb_per_primer=8, generate_images=True
        )

        # FAR TOO SLOW!!!

        # test_primers_descr=[ s.split('<img>')[0] for s in self.test_descr ]

        # count=self.test_model(
        # n_epoch, model,
        # test_primers_descr,
        # nb_per_primer=1, generate_images=False
        # )

        # with open(f'perf_{n_epoch:04d}.txt', 'w') as f:
        # for i in range(count.size(0)):
        # for j in range(count.size(1)):
        # f.write(f'{count[i,j]}')
        # f.write(" " if j<count.size(1)-1 else "\n")


######################################################################


class TaskWiki103(Task):
    def __init__(
        self,
        batch_size,
        len_min=10,
        len_max=200,
        min_freq=100,
        device=torch.device("cpu"),
    ):

        self.batch_size = batch_size
        self.len_min = len_min
        self.len_max = len_max
        self.min_freq = min_freq
        self.device = device

        self.tokenizer = torchtext.data.get_tokenizer("basic_english")
        train_iter = torchtext.datasets.WikiText103(split="train", root="./data/nlp/")

        # Mostly for debug
        if args.data_size is not None:
            train_iter = itertools.islice(train_iter, args.data_size)

        def yield_tokens():
            for l in tqdm.tqdm(train_iter, desc="vocab"):
                yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            yield_tokens(), specials=["<unk>", "<nul>"], min_freq=self.min_freq
        )

        self.vocab.set_default_index(self.vocab["<unk>"])

    # Makes a tensor from a list of lists of tokens
    def tensorize(self, s):
        a = max(len(x) for x in s)
        return torch.tensor([self.vocab(x + ["<nul>"] * (a - len(x))) for x in s])

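    # Group the tokenized lines whose length is in [len_min, len_max] into
    # batches of batch_size sequences, padded with <nul>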
    def yield_batches(self, ds):
        s = []
        for l in ds:
            q = self.tokenizer(l)
            if len(q) >= self.len_min and len(q) <= self.len_max:
                s += [q]
                if len(s) == self.batch_size:
                    yield self.tensorize(s)
                    s = []

        if len(s) > 0:
            yield self.tensorize(s)

    def batches(self, split="train"):
        data_iter = torchtext.datasets.WikiText103(split=split, root="./data/nlp/")

        # Mostly for debug
        if args.data_size is not None:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc=f"epoch-{split}"))

    def vocabulary_size(self):
        return len(self.vocab)

    def produce_results(self, n_epoch, model):
        nb_tokens = 50
        file_name = f"result_wiki103_{n_epoch:04d}.txt"

        with open(file_name, "w") as outfile:
            for primer in [
                "the cat is hunting a",
                "paris is the capital",
                "cars are convenient",
                "the difference between men and women is",
                "the object was blue all over and green all over it was",
                "cherries are red and lemons are",
                "cherries are sweet and lemons are",
                "two plus three equals",
                "deep learning is",
            ]:
                t_primer = self.tokenizer(primer)
                t_generated = []

                for j in range(nb_tokens):

                    input = self.tensorize([t_primer + t_generated]).to(self.device)
                    input = F.pad(
                        input, (0, 1)
                    )  # Add the next token, the one to predict
                    output = model(input)
                    logits = output[0, -1]
                    if args.deterministic_synthesis:
                        t_next = logits.argmax()
                    else:
                        dist = torch.distributions.categorical.Categorical(
                            logits=logits
                        )
                        t_next = dist.sample()
                    t_generated.append(self.vocab.lookup_token(t_next))
                    if t_generated[-1] == "<nul>":
                        break

                s = " ".join(t_generated)

                outfile.write(f"<{primer}> {s}\n")

        log_string(f"wrote {file_name}")


######################################################################


class TaskMNIST(Task):
    def __init__(self, batch_size, device=torch.device("cpu")):
        self.device = device
        self.batch_size = batch_size

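    # Each MNIST image is flattened into a sequence of 28*28 pixel values,
    # every pixel being a token in {0, ..., 255}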
    def batches(self, split="train"):
        assert split in {"train", "test"}
        data_set = torchvision.datasets.MNIST(
            root="./data", train=(split == "train"), download=True
        )
        data_input = data_set.data.view(-1, 28 * 28).long()
        if args.data_size is not None:
            data_input = data_input[: args.data_size]
        for batch in tqdm.tqdm(
            data_input.split(self.batch_size), desc=f"epoch-{split}"
        ):
            yield batch

    def vocabulary_size(self):
        return 256

    def produce_results(self, n_epoch, model):
        nb_samples = 64
        results = autoregression(
            model, self.batch_size, nb_samples, 28 * 28, device=self.device
        )
        image_name = f"result_mnist_{n_epoch:04d}.png"
        torchvision.utils.save_image(
            1 - results.reshape(-1, 1, 28, 28) / 255.0,
            image_name,
            nrow=16,
            pad_value=0.8,
        )
        log_string(f"wrote {image_name}")


######################################################################

log_string(f"device {device}")

if args.data == "wiki103":
    task = TaskWiki103(batch_size=args.batch_size, device=device)
elif args.data in {"mnist", "mnist-debug"}:
    task = TaskMNIST(batch_size=args.batch_size, device=device)
elif args.data == "picoclvr":
    task = TaskPicoCLVR(
        batch_size=args.batch_size,
        height=args.picoclvr_height,
        width=args.picoclvr_width,
        nb_colors=args.picoclvr_nb_colors,
        device=device,
    )
else:
    raise ValueError(f"Unknown dataset {args.data}.")

vocabulary_size = task.vocabulary_size()

log_string(f"vocabulary_size {vocabulary_size}")

##############################

model = mygpt.MyGPT(
    vocabulary_size=vocabulary_size,
    dim_model=args.dim_model,
    dim_keys=args.dim_keys,
    dim_hidden=args.dim_hidden,
    nb_heads=args.nb_heads,
    nb_blocks=args.nb_blocks,
    dropout=args.dropout,
)

model.to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")

######################################################################

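# Resume from an existing checkpoint if there is one, restoring the model
# weights, the number of epochs already finished, and the RNG states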
nb_epochs_finished = 0

if args.no_checkpoint:
    log_string("not trying to load checkpoint.")

else:
    try:
        checkpoint = torch.load(args.checkpoint_name)
        nb_epochs_finished = checkpoint["nb_epochs_finished"]
        model.load_state_dict(checkpoint["model_state"])
        torch.set_rng_state(checkpoint["rng_state"])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])
        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")

    except FileNotFoundError:
        log_string("starting from scratch.")

    except:
        log_string("error when loading the checkpoint.")
        exit(1)

######################################################################

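# Empirical distribution of the training tokens; the exponential of its
# entropy is the perplexity of a unigram model and is logged as a reference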
token_count = 0
for input in task.batches(split="train"):
    token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)

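# Unless --learning_rate_end is negative, the learning rate is interpolated
# geometrically from --learning_rate down to --learning_rate_end over the
# epochs, and a fresh optimizer is created at each epoch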
for n_epoch in range(nb_epochs_finished, args.nb_epochs):

    if args.learning_rate_end < 0:
        lr = args.learning_rate
    else:
        u = n_epoch / (args.nb_epochs - 1)
        lr = math.exp(
            (1 - u) * math.log(args.learning_rate)
            + u * math.log(args.learning_rate_end)
        )
        log_string(f"learning_rate {lr}")

    if args.optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    elif args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif args.optim == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    else:
        raise ValueError(f"Unknown optimizer {args.optim}.")

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split="train"):
        input = input.to(device)
        output = model(input)
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.autograd.no_grad():

        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split="test"):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))

        log_string(
            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
        )

        task.produce_results(n_epoch, model)

    checkpoint = {
        "nb_epochs_finished": n_epoch + 1,
        "model_state": model.state_dict(),
        "rng_state": torch.get_rng_state(),
    }

    if torch.cuda.is_available():
        checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()

    torch.save(checkpoint, args.checkpoint_name)

######################################################################