# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, argparse, time, tqdm, itertools
10 import torch, torchtext, torchvision
12 from torch.nn import functional as F
16 ######################################################################
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

######################################################################

parser = argparse.ArgumentParser(description = 'My own GPT.')

parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--download',
                    action = 'store_true', default = False)

parser.add_argument('--seed',
                    type = int, default = 0)

# -1 means "use the per-dataset default chosen later".
parser.add_argument('--nb_epochs',
                    type = int, default = -1)

parser.add_argument('--batch_size',
                    type = int, default = 25)

parser.add_argument('--data',
                    type = str, default = 'wiki103')

# -1 means "use the whole dataset".
parser.add_argument('--data_size',
                    type = int, default = -1)

parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

# BUG FIX: 'store_true' combined with default = True made this flag
# impossible to turn off from the command line.  The original option is
# kept unchanged for backward compatibility, and an explicit off-switch
# writing to the same destination is added.
parser.add_argument('--synthesis_sampling',
                    action = 'store_true', default = True)
parser.add_argument('--no_synthesis_sampling',
                    dest = 'synthesis_sampling', action = 'store_false')

parser.add_argument('--no_checkpoint',
                    action = 'store_true', default = False)

parser.add_argument('--checkpoint_name',
                    type = str, default = 'checkpoint.pth')

##############################
# picoclvr-specific options

parser.add_argument('--picoclvr_many_colors',
                    action = 'store_true', default = False)

parser.add_argument('--picoclvr_height',
                    type = int, default = 12)

parser.add_argument('--picoclvr_width',
                    type = int, default = 16)

######################################################################

args = parser.parse_args()

log_file = open(args.log_filename, 'w')

torch.manual_seed(args.seed)
99 ######################################################################
def log_string(s):
    """Write s, prefixed with a time stamp, to the log file (when open) and to stdout."""
    # NOTE(review): the extraction lost the 'def' header and the stdout echo
    # of this helper (the embedded original numbering jumps 99 -> 102 and
    # 105 -> 112); it is reconstructed conservatively so that the many
    # log_string(...) calls elsewhere in the file keep working.
    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
    if log_file is not None:
        log_file.write(t + s + '\n')
        log_file.flush()  # keep the log usable while training is still running
    print(t + s)
    sys.stdout.flush()

######################################################################

# Log the full configuration, one 'args.name value' line per argument.
for n in vars(args):
    log_string(f'args.{n} {getattr(args, n)}')
114 ######################################################################
# NOTE(review): garbled fragment of an autoregressive synthesis helper --
# the embedded original line numbering jumps (118 -> 121, 127 -> 133), so the
# 'def' header, the greedy (presumably argmax) else-branch, the write-back of
# the chosen token and the return are missing from this extraction.
118 model, nb_samples, nb_tokens_to_generate, starting_input = None,
# Output buffer: one row per sample, one column per generated token id.
121 results = torch.zeros(nb_samples, nb_tokens_to_generate, dtype = torch.int64, device = device)
# Generate batch by batch; each step feeds the current prefix back in.
122 for input in results.split(self.batch_size):
123 for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
124 output = model(input)
# Logits at position s parameterize the next-token distribution.
125 logits = output[:, s]
126 if args.synthesis_sampling:
# Sample from the categorical distribution over the logits.
127 dist = torch.distributions.categorical.Categorical(logits = logits)
133 ######################################################################
# NOTE(review): abstract interface implemented by TaskPicoCLVR, TaskWiki103
# and TaskMNIST below.  The 'class Task:' header and the stub bodies
# (presumably 'pass') are missing from this extraction (numbering jumps).
# Yield training or test mini-batches.
136 def batches(self, split = 'train'):
# Number of distinct token values the model must handle.
139 def vocabulary_size(self):
# Generate samples / evaluation artifacts after epoch n_epoch.
142 def produce_results(self, n_epoch, model, nb_tokens = 50):
145 ######################################################################
# Task: model sequences "description tokens ... <img> pixel tokens" produced
# by the project-local 'picoclvr' module (not visible in this extraction).
# NOTE(review): the embedded original line numbering jumps repeatedly inside
# this class -- the missing lines (attribute assignments such as self.height /
# self.width / self.device, list delimiters, else-branches) are flagged below
# rather than guessed at.
149 class TaskPicoCLVR(Task):
151 def __init__(self, batch_size,
152 height, width, many_colors = False,
153 device = torch.device('cpu')):
# Local helper: generate nb scene descriptions and right-pad them with
# '<unk>' so every description has the same token length.
155 def generate_descr(nb):
156 descr = picoclvr.generate(
158 height = self.height, width = self.width,
159 many_colors = many_colors
162 descr = [ s.strip().split(' ') for s in descr ]
163 l = max([ len(s) for s in descr ])
164 descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
# NOTE(review): lines assigning self.height, self.width and self.device
# (all read above and below) are among the missing lines.
170 self.batch_size = batch_size
# data_size <= 0 means "use the default corpus size".
172 nb = args.data_size if args.data_size > 0 else 250000
# 80% / 20% train / test split.
174 self.train_descr = generate_descr((nb * 4) // 5)
175 self.test_descr = generate_descr((nb * 1) // 5)
177 # Build the tokenizer
# NOTE(review): the initialization of 'tokens' (presumably a set) and the
# inner 'for s in d:' loop header are missing from this extraction.
179 for d in [ self.train_descr, self.test_descr ]:
181 for t in s: tokens.add(t)
182 self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
183 self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
# Tokenize the train and test corpora into integer tensors.
185 t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
186 self.train_input = torch.tensor(t, device = self.device)
187 t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
188 self.test_input = torch.tensor(t, device = self.device)
# Yield mini-batches of the requested split.  NOTE(review): the
# 'if split == ...' branching and 'yield batch' lines are missing.
190 def batches(self, split = 'train'):
191 assert split in { 'train', 'test' }
193 for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
196 for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
199 def vocabulary_size(self):
200 return len(self.token2id)
# Autoregressively extend 'primer' by nb_tokens tokens, one model call per
# token (the whole prefix is re-encoded and fed back in at every step).
202 def generate(self, primer, model, nb_tokens):
203 t_primer = primer.strip().split(' ')
# NOTE(review): the initialization of t_generated is missing.
206 for j in range(nb_tokens):
# The trailing 0 is a placeholder slot for the token to be predicted.
207 t = [ [ self.token2id[u] for u in t_primer + t_generated ] + [ 0 ] ]
208 input = torch.tensor(t, device = self.device)
209 output = model(input)
210 logits = output[0, -1]
# Either sample from the categorical distribution, or (in the missing
# else-branch, presumably) take the argmax.
211 if args.synthesis_sampling:
212 dist = torch.distributions.categorical.Categorical(logits = logits)
216 t_generated.append(self.id2token[t.item()])
218 return ' '.join(t_primer + t_generated)
# Generate images from a fixed set of primer descriptions, save a PNG grid,
# and log how many requested properties are missing from the generations.
220 def produce_results(self, n_epoch, model, nb_tokens = None):
221 if nb_tokens is None:
# Default: one token per pixel plus a small margin for separators.
222 nb_tokens = self.height * self.width + 3
# NOTE(review): the 'for primer in [' opener, 'nb_per_primer' and 'descr'
# initializations, and the list closer are missing around these primers.
227 'red above green <sep> green top <sep> blue right of red <img>',
228 'there is red <sep> there is yellow <sep> there is blue <img>',
229 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
230 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
233 for k in range(nb_per_primer):
234 descr.append(self.generate(primer, model, nb_tokens))
# Render each generated description to an image and save the grid.
236 img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
237 img = torch.cat(img, 0)
238 image_name = f'result_picoclvr_{n_epoch:04d}.png'
239 torchvision.utils.save_image(
241 image_name, nrow = nb_per_primer, pad_value = 0.8
243 log_string(f'wrote {image_name}')
# Average, over the generations, of the number of requested properties not
# satisfied by the generated image.  NOTE(review): the surrounding
# 'nb_missing = sum([' and its closing lines are missing.
246 x[2] for x in picoclvr.nb_missing_properties(
248 height = self.height, width = self.width
252 log_string(f'nb_missing {nb_missing / len(descr):.02f}')
254 ######################################################################
# Word-level language modelling on WikiText-103 via torchtext.  Lines of
# len_min..len_max tokens are batched and right-padded with '<non>'.
# NOTE(review): the embedded original numbering jumps inside this class;
# missing lines are flagged below rather than guessed at.
256 class TaskWiki103(Task):
258 def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
259 device = torch.device('cpu')):
261 self.batch_size = batch_size
262 self.len_min = len_min
263 self.len_max = len_max
264 self.min_freq = min_freq
# NOTE(review): a 'self.device = device' assignment (read in
# produce_results below) appears to be among the missing lines.
267 self.tokenizer = torchtext.data.get_tokenizer('basic_english')
268 train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
# data_size > 0 truncates the corpus for quick experiments.
271 if args.data_size > 0:
272 train_iter = itertools.islice(train_iter, args.data_size)
# NOTE(review): the generator-function header wrapping the next two lines
# is missing from this extraction.
275 for l in tqdm.tqdm(train_iter, desc = 'vocab'):
276 yield self.tokenizer(l)
# '<unk>' is the default for out-of-vocabulary words, '<non>' is padding.
278 self.vocab = torchtext.vocab.build_vocab_from_iterator(
280 specials = [ '<unk>', '<non>' ],
281 min_freq = self.min_freq
284 self.vocab.set_default_index(self.vocab[ '<unk>' ])
# Convert a list of token lists into one right-padded integer tensor.
286 def tensorize(self, s):
287 a = max(len(x) for x in s)
288 return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
# Group the length-filtered lines of ds into batch_size batches; the final
# partial batch is flushed too.  NOTE(review): the accumulator handling
# ('s = []', the 'for l in ds:' header, 's.append(q)', 'if len(s) > 0:')
# is missing from this extraction.
290 def yield_batches(self, ds):
293 q = self.tokenizer(l)
294 if len(q) >= self.len_min and len(q) <= self.len_max:
296 if len(s) == self.batch_size:
297 yield self.tensorize(s)
301 yield self.tensorize(s)
303 def batches(self, split = 'train'):
304 data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
307 if args.data_size > 0:
308 data_iter = itertools.islice(data_iter, args.data_size)
310 return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
312 def vocabulary_size(self):
313 return len(self.vocab)
# Autoregressively complete a fixed set of primer sentences and write the
# generations to a text file.
315 def produce_results(self, n_epoch, model, nb_tokens = 50):
316 file_name = f'result_wiki103_{n_epoch:04d}.txt'
318 with open(file_name, 'w') as outfile:
# NOTE(review): the 'for primer in [' opener of this primer list is missing.
320 'the cat is hunting a',
321 'paris is the capital',
322 'cars are convenient',
323 'the difference between men and women is',
324 'the object was blue all over and green all over it was',
325 'cherries are red and lemons are',
326 'cherries are sweet and lemons are',
327 'two plus three equals',
330 t_primer = self.tokenizer(primer)
# NOTE(review): the initialization of t_generated is missing.
333 for j in range(nb_tokens):
# Re-encode primer + generated prefix and predict the next token.
335 input = self.tensorize([ t_primer + t_generated ]).to(self.device)
336 output = model(input)
337 logits = output[0, -1]
# Sample, or (missing else-branch, presumably argmax) pick, the token.
338 if args.synthesis_sampling:
339 dist = torch.distributions.categorical.Categorical(logits = logits)
343 t_generated.append(self.vocab.lookup_token(t))
# Generating the padding token ends the sentence.
344 if t_generated[-1] == '<non>': break
346 s = ' '.join(t_generated)
348 outfile.write(f'<{primer}> {s}\n')
350 log_string(f'wrote {file_name}')
352 ######################################################################
# Pixel-level modelling of MNIST: each image is a sequence of 28*28 byte
# values.  NOTE(review): the embedded numbering jumps inside this class;
# missing lines are flagged below rather than guessed at.
354 class TaskMNIST(Task):
356 def __init__(self, batch_size, device = torch.device('cpu')):
358 self.batch_size = batch_size
# NOTE(review): a 'self.device = device' assignment (read in
# produce_results below) appears to be among the missing lines.
360 def batches(self, split = 'train'):
361 assert split in { 'train', 'test' }
362 data_set = torchvision.datasets.MNIST(
363 root = './data', train = (split == 'train'),
# NOTE(review): the 'download = ...' argument and the closing paren of the
# MNIST(...) call are missing from this extraction.
366 data_input = data_set.data.view(-1, 28 * 28).long()
# data_size >= 0 truncates the dataset for quick experiments.
367 if args.data_size >= 0:
368 data_input = data_input[:args.data_size]
# NOTE(review): the 'yield batch' body of this loop is missing.
369 for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
# NOTE(review): the return value is missing -- presumably 256, one token
# per possible byte value; confirm against the original file.
372 def vocabulary_size(self):
# Autoregressively synthesize nb_samples images pixel by pixel and save
# them as an inverted grayscale PNG grid.
375 def produce_results(self, n_epoch, model, nb_samples = 64):
376 results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
377 for input in results.split(self.batch_size):
378 for s in tqdm.tqdm(range(input.size(1)), desc = 'synth'):
379 output = model(input)
380 logits = output[:, s]
# Sample, or (missing else-branch, presumably argmax) pick, the pixel
# value; the write-back into the sequence is also among the missing lines.
381 if args.synthesis_sampling:
382 dist = torch.distributions.categorical.Categorical(logits = logits)
388 image_name = f'result_mnist_{n_epoch:04d}.png'
389 torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
390 image_name, nrow = 16, pad_value = 0.8)
391 log_string(f'wrote {image_name}')
393 ######################################################################
log_string(f'device {device}')

# Instantiate the task named on the command line, each with its own
# sensible default number of epochs (used when --nb_epochs is -1).
# NOTE(review): the closing 'device = device)' of the TaskPicoCLVR call and
# the final 'else:' were lost in the extraction and are restored here.
if args.data == 'wiki103':
    nb_epochs_default = 10
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    nb_epochs_default = 25
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    nb_epochs_default = 10
    task = TaskPicoCLVR(batch_size = args.batch_size,
                        height = args.picoclvr_height,
                        width = args.picoclvr_width,
                        many_colors = args.picoclvr_many_colors,
                        device = device)
else:
    raise ValueError(f'Unknown dataset {args.data}.')
# The vocabulary size ties the task's token set to the model's output layer.
413 vocabulary_size = task.vocabulary_size()
415 log_string(f'vocabulary_size {vocabulary_size}')
417 ##############################
# NOTE(review): the constructor-call opener (presumably 'model = <GPT
# class>(' from a project-local module not visible in this extraction) and
# its closing paren / a 'model.to(device)' are missing here.
420 vocabulary_size = vocabulary_size,
421 dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
422 nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
# Log the parameter count, in absolute and in millions.
427 nb_parameters = sum(p.numel() for p in model.parameters())
428 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
430 ######################################################################
# Select the optimizer from the command line.  NOTE(review): the 'else:'
# guard before the raise was lost in the extraction and is restored here,
# so an unknown name fails loudly instead of the raise firing unconditionally.
if args.optim == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
else:
    raise ValueError(f'Unknown optimizer {args.optim}.')
441 ######################################################################
nb_epochs_finished = 0

if args.no_checkpoint:
    log_string(f'Not trying to load checkpoint.')
else:
    # Resume transparently when a checkpoint exists; its absence is the
    # normal "first run" situation, anything else is a real error.
    # NOTE(review): the 'else:' / 'try:' scaffolding and the final handler
    # were lost in the extraction and are restored here.
    try:
        checkpoint = torch.load(args.checkpoint_name, map_location = device)
        nb_epochs_finished = checkpoint['nb_epochs_finished']
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
    except FileNotFoundError:
        log_string('Starting from scratch.')
    except Exception:
        # Narrowed from a presumed bare 'except:' so KeyboardInterrupt and
        # SystemExit are not swallowed here.
        log_string('Error when loading the checkpoint.')
        exit(1)
463 ######################################################################
# NOTE(review): the extraction lost several lines in this loop (the embedded
# original numbering jumps); the canonical scaffolding that the visible code
# clearly requires -- model.train()/eval(), loss.backward(), optimizer.step(),
# the checkpoint dict braces -- has been restored.

nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default

# Entropy of the empirical token distribution: exp(H) is a reference
# perplexity that a memoryless model would achieve on this corpus.
token_count = 0
for input in task.batches(split = 'train'):
    token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
h = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(h)
log_string(f'Train set perplexity {train_set_perplexity}')

for k in range(nb_epochs_finished, nb_epochs):

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        # cross_entropy expects (N, C, T); the model outputs (N, T, C).
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluation needs no gradients; torch.no_grad() is the modern spelling
    # of the original torch.autograd.no_grad().
    with torch.no_grad():

        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        # Cap the average losses so math.exp cannot overflow.
        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

        log_string(f'perplexity {k} train {train_perplexity} test {test_perplexity}')

        task.produce_results(k, model)

    # Checkpoint after every epoch so an interrupted run resumes where it
    # stopped (see the loading code above).
    checkpoint = {
        'nb_epochs_finished': k + 1,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict()
    }

    torch.save(checkpoint, args.checkpoint_name)
520 ######################################################################