3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, argparse, time, tqdm, itertools
10 import torch, torchtext, torchvision
12 from torch.nn import functional as F
######################################################################

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

######################################################################
# Command-line interface of the training script.
parser = argparse.ArgumentParser(description = 'My own GPT.')

# Logging / reproducibility
parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--seed',
                    type = int, default = 0)

# Training schedule; -1 means "use the per-dataset default".
parser.add_argument('--nb_epochs',
                    type = int, default = -1)

parser.add_argument('--batch_size',
                    type = int, default = 25)

# Dataset selection; --data_size -1 means "use the full dataset".
parser.add_argument('--data',
                    type = str, default = 'wiki103')

parser.add_argument('--data_size',
                    type = int, default = -1)

# Optimizer
parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

# Model geometry
parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

# The original flag was `store_true` with default = True, so passing it
# could never change anything. Keep it for backward compatibility and
# add an explicit opt-out that actually disables sampling.
parser.add_argument('--synthesis_sampling',
                    action = 'store_true', default = True)

parser.add_argument('--no_synthesis_sampling',
                    dest = 'synthesis_sampling', action = 'store_false')

parser.add_argument('--no_checkpoint',
                    action = 'store_true', default = False)

parser.add_argument('--checkpoint_name',
                    type = str, default = 'checkpoint.pth')

##############################

# picoclvr-specific scene geometry
parser.add_argument('--picoclvr_nb_colors',
                    type = int, default = 5)

parser.add_argument('--picoclvr_height',
                    type = int, default = 12)

parser.add_argument('--picoclvr_width',
                    type = int, default = 16)
######################################################################

args = parser.parse_args()

# All subsequent logging is appended to this file.
log_file = open(args.log_filename, 'w')

# Seed the PRNG for reproducibility.
torch.manual_seed(args.seed)
######################################################################

def log_string(s):
    """Write a timestamped message to the log file and to stdout.

    NOTE(review): the `def` line and the tail of this function were
    missing from the numbered listing; the body below follows the
    visible lines (timestamp, write to log_file) and the obvious
    completion -- confirm against the original file.
    """
    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
    if log_file is not None:
        log_file.write(t + s + '\n')
        log_file.flush()
    print(t + s)
    sys.stdout.flush()

# Log every command-line argument value.
# NOTE(review): the loop header was missing from the listing; `n` is
# used as an attribute name of `args`, so vars(args) is the natural
# source -- confirm against the original file.
for n in vars(args):
    log_string(f'args.{n} {getattr(args, n)}')

######################################################################
def autoregression(
        model,
        nb_samples, nb_tokens_to_generate, starting_input = None,
        device = torch.device('cpu')
):
    """Generate token sequences from `model`, one position at a time.

    Returns an int64 tensor of shape (nb_samples, first +
    nb_tokens_to_generate) where `first` is the length of the optional
    `starting_input` prompt prefix.

    NOTE(review): the `def` header, the `first = 0` branch, the
    write-back of the chosen token, and the final `return` were missing
    from the numbered listing and have been reconstructed; the function
    name is taken from the call in TaskMNIST.produce_results. Confirm
    against the original file.
    """
    results = torch.zeros(
        nb_samples, nb_tokens_to_generate,
        dtype = torch.int64, device = device
    )

    if starting_input is None:
        first = 0
    else:
        first = starting_input.size(1)
        results = torch.cat((starting_input, results), 1)

    # Tensor.split returns views of `results`, so writing into `input`
    # below fills `results` in place.
    for input in results.split(args.batch_size):
        for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
            output = model(input)
            logits = output[:, s]
            # Sample from the categorical distribution over the
            # vocabulary, or take the argmax, per --synthesis_sampling.
            if args.synthesis_sampling:
                dist = torch.distributions.categorical.Categorical(logits = logits)
                t_next = dist.sample()
            else:
                t_next = logits.argmax(1)
            input[:, s] = t_next

    return results

######################################################################
# Abstract interface implemented by TaskPicoCLVR / TaskWiki103 /
# TaskMNIST below (they all subclass `Task`).
# NOTE(review): the `class Task:` header and the method bodies are
# missing from this numbered listing; only the signatures are visible.
145 def batches(self, split = 'train'):
148 def vocabulary_size(self):
151 def produce_results(self, n_epoch, model, nb_tokens = 50):
154 ######################################################################
# Task whose samples are token sequences describing picoclvr scenes.
# NOTE(review): this is a numbered listing with lines missing (gaps in
# the leading line numbers); the comments below describe only what the
# visible lines establish.
158 class TaskPicoCLVR(Task):
# Builds train/test descriptions, a token vocabulary, and the input
# tensors. NOTE(review): `self.height`, `self.width` and `self.device`
# are read below but their assignments are not visible -- presumably
# among the missing lines.
160 def __init__(self, batch_size,
161 height, width, nb_colors = 5,
162 device = torch.device('cpu')):
# Generate `nb` scene descriptions, split each into tokens, and pad
# every sequence to the common maximum length with '<unk>'.
164 def generate_descr(nb):
165 descr = picoclvr.generate(
167 height = self.height, width = self.width,
168 nb_colors = nb_colors
171 descr = [ s.strip().split(' ') for s in descr ]
172 l = max([ len(s) for s in descr ])
173 descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
179 self.batch_size = batch_size
# Dataset size: --data_size if given, otherwise 250000 samples, split
# 4/5 train and 1/5 test.
181 nb = args.data_size if args.data_size > 0 else 250000
183 self.train_descr = generate_descr((nb * 4) // 5)
184 self.test_descr = generate_descr((nb * 1) // 5)
186 # Build the tokenizer
# NOTE(review): the initialisation of `tokens` (presumably a set) and
# the inner `for s in d:` loop header are missing from this extract.
188 for d in [ self.train_descr, self.test_descr ]:
190 for t in s: tokens.add(t)
# Forward and backward token <-> id maps.
191 self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
192 self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
# Tokenized train/test inputs as id tensors on self.device.
194 t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
195 self.train_input = torch.tensor(t, device = self.device)
196 t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
197 self.test_input = torch.tensor(t, device = self.device)
# Yield mini-batches of the requested split.
# NOTE(review): the `yield batch` bodies and the if/else selecting the
# split are missing from this extract.
199 def batches(self, split = 'train'):
200 assert split in { 'train', 'test' }
202 for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
205 for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
208 def vocabulary_size(self):
209 return len(self.token2id)
# Autoregressively extend `primer` with nb_tokens model-generated
# tokens and return the whole description as one string.
# NOTE(review): the initialisation of `t_generated` (appended to
# below, presumably []) is missing from this extract.
211 def generate(self, primer, model, nb_tokens):
212 t_primer = primer.strip().split(' ')
215 for j in range(nb_tokens):
# Re-encode primer + generated-so-far, pad one slot for the token to
# predict, and take the logits at the last position.
216 t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
217 input = torch.tensor(t, device = self.device)
218 input = F.pad(input, (0, 1)) # Add the next token, the one to predict
219 output = model(input)
220 logits = output[0, -1]
# Sample or argmax per --synthesis_sampling (the `else:` line is
# missing from this extract).
221 if args.synthesis_sampling:
222 dist = torch.distributions.categorical.Categorical(logits = logits)
223 t_next = dist.sample()
225 t_next = logits.argmax()
226 t_generated.append(self.id2token[t_next.item()])
228 return ' '.join(t_primer + t_generated)
# Generate several completions per hard-coded primer, render them to a
# PNG grid, and log how many requested properties the scenes miss.
230 def produce_results(self, n_epoch, model, nb_tokens = None):
# Default budget: one token per cell plus a small margin.
231 if nb_tokens is None:
232 nb_tokens = self.height * self.width + 3
# NOTE(review): the list wrapper around these primer strings, the loop
# over them, and the definitions of `descr` and `nb_per_primer` are
# missing from this extract.
237 'red above green <sep> green top <sep> blue right of red <img>',
238 'there is red <sep> there is yellow <sep> there is blue <img>',
239 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
240 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
243 for k in range(nb_per_primer):
244 descr.append(self.generate(primer, model, nb_tokens))
# Render each generated description to an image and save the grid.
246 img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
247 img = torch.cat(img, 0)
248 image_name = f'result_picoclvr_{n_epoch:04d}.png'
249 torchvision.utils.save_image(
251 image_name, nrow = nb_per_primer, pad_value = 0.8
253 log_string(f'wrote {image_name}')
# NOTE(review): the `nb_missing = sum([ ... ])` wrapper around the
# next two lines appears to be missing from this extract -- the name
# `nb_missing` is used in the log line below.
256 x[2] for x in picoclvr.nb_missing_properties(
258 height = self.height, width = self.width
262 log_string(f'nb_missing {nb_missing / len(descr):.02f}')
264 ######################################################################
# Language-modeling task on WikiText-103, tokenized with torchtext's
# 'basic_english' tokenizer.
# NOTE(review): numbered listing with lines missing (gaps in the
# leading numbers); the comments describe only the visible lines.
266 class TaskWiki103(Task):
# Builds the vocabulary from the training split; words occurring fewer
# than min_freq times map to '<unk>', padding uses '<non>'.
268 def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
269 device = torch.device('cpu')):
271 self.batch_size = batch_size
272 self.len_min = len_min
273 self.len_max = len_max
274 self.min_freq = min_freq
# NOTE(review): `self.device` is read in produce_results but its
# assignment is not visible -- presumably among the missing lines.
277 self.tokenizer = torchtext.data.get_tokenizer('basic_english')
278 train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
# Optionally cap the number of lines used to build the vocabulary.
281 if args.data_size > 0:
282 train_iter = itertools.islice(train_iter, args.data_size)
# NOTE(review): the enclosing generator `def` header for this yield is
# missing from this extract.
285 for l in tqdm.tqdm(train_iter, desc = 'vocab'):
286 yield self.tokenizer(l)
288 self.vocab = torchtext.vocab.build_vocab_from_iterator(
290 specials = [ '<unk>', '<non>' ],
291 min_freq = self.min_freq
# Out-of-vocabulary words map to '<unk>'.
294 self.vocab.set_default_index(self.vocab[ '<unk>' ])
# Pad every token sequence in `s` to the batch maximum with '<non>'
# and return the resulting id tensor.
296 def tensorize(self, s):
297 a = max(len(x) for x in s)
298 return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
# Group tokenized lines whose length lies in [len_min, len_max] into
# batches of batch_size, with a final smaller batch if material is
# left over.
# NOTE(review): the accumulator initialisation, the `for l in ds:`
# loop header, the append, and the trailing `if` are missing from this
# extract.
300 def yield_batches(self, ds):
303 q = self.tokenizer(l)
304 if len(q) >= self.len_min and len(q) <= self.len_max:
306 if len(s) == self.batch_size:
307 yield self.tensorize(s)
311 yield self.tensorize(s)
313 def batches(self, split = 'train'):
314 data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
317 if args.data_size > 0:
318 data_iter = itertools.islice(data_iter, args.data_size)
320 return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
322 def vocabulary_size(self):
323 return len(self.vocab)
# Complete hard-coded primers with nb_tokens generated tokens each and
# write the completions to a per-epoch text file.
325 def produce_results(self, n_epoch, model, nb_tokens = 50):
326 file_name = f'result_wiki103_{n_epoch:04d}.txt'
328 with open(file_name, 'w') as outfile:
# NOTE(review): the list wrapper around these primer strings and the
# loop over them are missing from this extract.
330 'the cat is hunting a',
331 'paris is the capital',
332 'cars are convenient',
333 'the difference between men and women is',
334 'the object was blue all over and green all over it was',
335 'cherries are red and lemons are',
336 'cherries are sweet and lemons are',
337 'two plus three equals',
340 t_primer = self.tokenizer(primer)
# NOTE(review): `t_generated` is appended to below but its
# initialisation is missing from this extract.
343 for j in range(nb_tokens):
# Pad one slot for the token to predict; logits at the last position.
345 input = self.tensorize([ t_primer + t_generated ]).to(self.device)
346 input = F.pad(input, (0, 1)) # Add the next token, the one to predict
347 output = model(input)
348 logits = output[0, -1]
# Sample or argmax per --synthesis_sampling (the `else:` line is
# missing from this extract).
349 if args.synthesis_sampling:
350 dist = torch.distributions.categorical.Categorical(logits = logits)
351 t_next = dist.sample()
353 t_next = logits.argmax()
354 t_generated.append(self.vocab.lookup_token(t_next))
# Stop as soon as the padding token is produced.
355 if t_generated[-1] == '<non>': break
357 s = ' '.join(t_generated)
359 outfile.write(f'<{primer}> {s}\n')
361 log_string(f'wrote {file_name}')
363 ######################################################################
# Autoregressive modeling of MNIST digits as length-784 sequences of
# pixel values.
# NOTE(review): numbered listing with lines missing (gaps in the
# leading numbers); the comments describe only the visible lines.
365 class TaskMNIST(Task):
367 def __init__(self, batch_size, device = torch.device('cpu')):
369 self.batch_size = batch_size
# NOTE(review): `self.device` is read in produce_results but its
# assignment is not visible -- presumably among the missing lines.
# Yield flattened 28*28 digit batches of the requested split.
371 def batches(self, split = 'train'):
372 assert split in { 'train', 'test' }
373 data_set = torchvision.datasets.MNIST(
374 root = './data', train = (split == 'train'),
# Each image becomes a length-784 int sequence of pixel intensities.
377 data_input = data_set.data.view(-1, 28 * 28).long()
378 if args.data_size >= 0:
379 data_input = data_input[:args.data_size]
# NOTE(review): the `yield batch` body of this loop is missing from
# this extract.
380 for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
# NOTE(review): the return value is missing here -- presumably 256,
# one class per possible pixel intensity; confirm in the original.
383 def vocabulary_size(self):
# Sample nb_samples digits pixel-by-pixel via autoregression() and
# save them as a PNG grid.
386 def produce_results(self, n_epoch, model, nb_samples = 64):
387 results = autoregression(model, nb_samples, 28 * 28, device = self.device)
388 image_name = f'result_mnist_{n_epoch:04d}.png'
# Pixel values are scaled to [0, 1] and inverted for display.
389 torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
390 image_name, nrow = 16, pad_value = 0.8)
391 log_string(f'wrote {image_name}')
393 ######################################################################
# Instantiate the task selected by --data, with a per-dataset default
# number of epochs, then build the model.
# NOTE(review): numbered listing with lines missing (gaps in the
# leading numbers); the comments describe only the visible lines.
395 log_string(f'device {device}')
397 if args.data == 'wiki103':
398 nb_epochs_default = 10
399 task = TaskWiki103(batch_size = args.batch_size, device = device)
400 elif args.data == 'mnist':
401 nb_epochs_default = 25
402 task = TaskMNIST(batch_size = args.batch_size, device = device)
403 elif args.data == 'picoclvr':
404 nb_epochs_default = 10
405 task = TaskPicoCLVR(batch_size = args.batch_size,
406 height = args.picoclvr_height,
407 width = args.picoclvr_width,
408 nb_colors = args.picoclvr_nb_colors,
# NOTE(review): the closing of the TaskPicoCLVR call and the `else:`
# introducing the raise below are missing from this extract.
411 raise ValueError(f'Unknown dataset {args.data}.')
413 vocabulary_size = task.vocabulary_size()
415 log_string(f'vocabulary_size {vocabulary_size}')
417 ##############################
# NOTE(review): these are keyword arguments of the model constructor;
# the `model = mygpt.MyGPT(` header, the closing of the call, and the
# move of the model to `device` are missing from this extract.
420 vocabulary_size = vocabulary_size,
421 dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
422 nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
# Report the model size in parameters.
427 nb_parameters = sum(p.numel() for p in model.parameters())
428 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
430 ######################################################################
# Build the optimizer selected with --optim; reject unknown names.
# NOTE(review): only the `else:` line was missing from the numbered
# listing; it is implied by the dangling `raise` at this nesting depth.
if args.optim == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
else:
    raise ValueError(f'Unknown optimizer {args.optim}.')

######################################################################
# Resume from a checkpoint unless --no_checkpoint is given, then train.
# NOTE(review): numbered listing with lines missing (gaps in the
# leading numbers); the comments describe only the visible lines.
443 nb_epochs_finished = 0
445 if args.no_checkpoint:
446 log_string(f'not trying to load checkpoint.')
# NOTE(review): the `else:` / `try:` introducing the load below are
# missing from this extract.
450 checkpoint = torch.load(args.checkpoint_name, map_location = device)
451 nb_epochs_finished = checkpoint['nb_epochs_finished']
452 model.load_state_dict(checkpoint['model_state'])
453 optimizer.load_state_dict(checkpoint['optimizer_state'])
454 log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
# A missing checkpoint file is not an error: training starts fresh.
456 except FileNotFoundError:
457 log_string('starting from scratch.')
# NOTE(review): the except clause guarding this message, and whatever
# follows it, are missing from this extract.
460 log_string('error when loading the checkpoint.')
463 ######################################################################
465 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
# Empirical token distribution of the training set, used to report an
# entropy-based perplexity baseline.
# NOTE(review): the initialisation of `token_count` is missing from
# this extract.
468 for input in task.batches(split = 'train'):
469 token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
470 token_probas = token_count / token_count.sum()
# Entropy H = -sum p log p; baseline perplexity = exp(H).
471 h = -torch.xlogy(token_probas, token_probas).sum()
472 train_set_perplexity = math.exp(h)
473 log_string(f'train set perplexity {train_set_perplexity}')
# Main training loop, resuming after any checkpointed epochs.
475 for k in range(nb_epochs_finished, nb_epochs):
479 nb_train_samples, acc_train_loss = 0, 0.0
481 for input in task.batches(split = 'train'):
482 input = input.to(device)
483 output = model(input)
# cross_entropy expects (batch, classes, positions), hence the
# transpose of the (batch, positions, classes) output.
484 loss = F.cross_entropy(output.transpose(1, 2), input)
485 acc_train_loss += loss.item() * input.size(0)
486 nb_train_samples += input.size(0)
# NOTE(review): the backward pass and optimizer step that normally
# follow zero_grad() are missing from this extract.
488 optimizer.zero_grad()
# Evaluation pass without gradient tracking.
492 with torch.autograd.no_grad():
496 nb_test_samples, acc_test_loss = 0, 0.0
498 for input in task.batches(split = 'test'):
499 input = input.to(device)
500 output = model(input)
501 loss = F.cross_entropy(output.transpose(1, 2), input)
502 acc_test_loss += loss.item() * input.size(0)
503 nb_test_samples += input.size(0)
# Mean losses are capped at 100 before exponentiation to avoid
# overflow in math.exp.
505 train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
506 test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
508 log_string(f'perplexity {k} train {train_perplexity} test {test_perplexity}')
510 task.produce_results(k, model)
# Save everything needed to resume training after this epoch.
# NOTE(review): the `checkpoint = {` opener and its closing brace are
# missing from this extract.
513 'nb_epochs_finished': k + 1,
514 'model_state': model.state_dict(),
515 'optimizer_state': optimizer.state_dict()
518 torch.save(checkpoint, args.checkpoint_name)
520 ######################################################################