3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, argparse, time, tqdm, itertools
10 import torch, torchtext, torchvision
12 from torch.nn import functional as F
16 ######################################################################
# Pick the compute device once at startup; prefer CUDA when available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20 ######################################################################
# Command-line interface. Each option maps one-to-one onto a training or
# model hyper-parameter used further down in the script.
parser = argparse.ArgumentParser(description = 'My own GPT.')

parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--download',
                    action='store_true', default = False)

parser.add_argument('--seed',
                    type = int, default = 0)

# -1 means "use the per-dataset default selected later in the script".
parser.add_argument('--nb_epochs',
                    type = int, default = -1)

parser.add_argument('--batch_size',
                    type = int, default = 25)

parser.add_argument('--data',
                    type = str, default = 'wiki103')

# -1 means "use the full dataset"; otherwise truncate to this many samples.
parser.add_argument('--data_size',
                    type = int, default = -1)

parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

# BUG FIX: 'store_true' combined with default = True meant the flag could
# never be switched off from the command line. Keep the original option
# (and its default) for backward compatibility, and add an explicit
# complement that disables sampling (falling back to greedy argmax
# decoding in the synthesis code).
parser.add_argument('--synthesis_sampling',
                    action='store_true', default = True)

parser.add_argument('--no_synthesis_sampling',
                    dest = 'synthesis_sampling', action = 'store_false')

parser.add_argument('--no_checkpoint',
                    action='store_true', default = False)

parser.add_argument('--checkpoint_name',
                    type = str, default = 'checkpoint.pth')

##############################

parser.add_argument('--picoclvr_nb_colors',
                    type = int, default = 5)

parser.add_argument('--picoclvr_height',
                    type = int, default = 12)

parser.add_argument('--picoclvr_width',
                    type = int, default = 16)

######################################################################

args = parser.parse_args()

# Kept open for the whole run; never explicitly closed (reclaimed at exit).
log_file = open(args.log_filename, 'w')

torch.manual_seed(args.seed)
99 ######################################################################
102 t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
104 if log_file is not None:
105 log_file.write(t + s + '\n')
112 log_string(f'args.{n} {getattr(args, n)}')
114 ######################################################################
118 nb_samples, nb_tokens_to_generate, starting_input = None,
119 device = torch.device('cpu')
121 results = torch.zeros(
122 nb_samples, nb_tokens_to_generate,
123 dtype = torch.int64, device = device
126 if starting_input is None:
129 first = starting_input.size(1)
130 results = torch.cat((starting_input, results), 1)
132 for input in results.split(args.batch_size):
133 for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
134 output = model(input)
135 logits = output[:, s]
136 if args.synthesis_sampling:
137 dist = torch.distributions.categorical.Categorical(logits = logits)
138 t_next = dist.sample()
140 t_next = logits.argmax(1)
145 ######################################################################
148 def batches(self, split = 'train'):
151 def vocabulary_size(self):
154 def produce_results(self, n_epoch, model, nb_tokens = 50):
157 ######################################################################
161 class TaskPicoCLVR(Task):
163 def __init__(self, batch_size,
164 height, width, nb_colors = 5,
165 device = torch.device('cpu')):
167 def generate_descr(nb):
168 descr = picoclvr.generate(
170 height = self.height, width = self.width,
171 nb_colors = nb_colors
174 descr = [ s.strip().split(' ') for s in descr ]
175 l = max([ len(s) for s in descr ])
176 descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
182 self.batch_size = batch_size
184 nb = args.data_size if args.data_size > 0 else 250000
186 self.train_descr = generate_descr((nb * 4) // 5)
187 self.test_descr = generate_descr((nb * 1) // 5)
189 # Build the tokenizer
191 for d in [ self.train_descr, self.test_descr ]:
193 for t in s: tokens.add(t)
194 self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
195 self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
197 t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
198 self.train_input = torch.tensor(t, device = self.device)
199 t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
200 self.test_input = torch.tensor(t, device = self.device)
202 def batches(self, split = 'train'):
203 assert split in { 'train', 'test' }
205 for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
208 for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
def vocabulary_size(self):
    """Return the number of distinct picoclvr tokens built in __init__."""
    return len(self.token2id)
214 def generate(self, primer, model, nb_tokens):
215 t_primer = primer.strip().split(' ')
218 for j in range(nb_tokens):
219 t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
220 input = torch.tensor(t, device = self.device)
221 input = F.pad(input, (0, 1)) # Add the next token, the one to predict
222 output = model(input)
223 logits = output[0, -1]
224 if args.synthesis_sampling:
225 dist = torch.distributions.categorical.Categorical(logits = logits)
226 t_next = dist.sample()
228 t_next = logits.argmax()
229 t_generated.append(self.id2token[t_next.item()])
231 return ' '.join(t_primer + t_generated)
233 def produce_results(self, n_epoch, model, nb_tokens = None):
234 if nb_tokens is None:
235 nb_tokens = self.height * self.width + 3
240 'red above green <sep> green top <sep> blue right of red <img>',
241 'there is red <sep> there is yellow <sep> there is blue <img>',
242 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
243 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
246 for k in range(nb_per_primer):
247 descr.append(self.generate(primer, model, nb_tokens))
249 img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
250 img = torch.cat(img, 0)
251 image_name = f'result_picoclvr_{n_epoch:04d}.png'
252 torchvision.utils.save_image(
254 image_name, nrow = nb_per_primer, pad_value = 0.8
256 log_string(f'wrote {image_name}')
259 x[2] for x in picoclvr.nb_missing_properties(
261 height = self.height, width = self.width
265 log_string(f'nb_missing {nb_missing / len(descr):.02f}')
267 ######################################################################
269 class TaskWiki103(Task):
271 def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
272 device = torch.device('cpu')):
274 self.batch_size = batch_size
275 self.len_min = len_min
276 self.len_max = len_max
277 self.min_freq = min_freq
280 self.tokenizer = torchtext.data.get_tokenizer('basic_english')
281 train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
284 if args.data_size > 0:
285 train_iter = itertools.islice(train_iter, args.data_size)
288 for l in tqdm.tqdm(train_iter, desc = 'vocab'):
289 yield self.tokenizer(l)
291 self.vocab = torchtext.vocab.build_vocab_from_iterator(
293 specials = [ '<unk>', '<non>' ],
294 min_freq = self.min_freq
297 self.vocab.set_default_index(self.vocab[ '<unk>' ])
299 def tensorize(self, s):
300 a = max(len(x) for x in s)
301 return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
303 def yield_batches(self, ds):
306 q = self.tokenizer(l)
307 if len(q) >= self.len_min and len(q) <= self.len_max:
309 if len(s) == self.batch_size:
310 yield self.tensorize(s)
314 yield self.tensorize(s)
def batches(self, split = 'train'):
    """Return a generator of tokenized WikiText103 batches for the given
    split (split names are whatever torchtext's WikiText103 accepts)."""
    data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

    # Mostly for debugging: optionally restrict the stream to the first
    # args.data_size lines.
    if args.data_size > 0:
        data_iter = itertools.islice(data_iter, args.data_size)

    return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
def vocabulary_size(self):
    """Return the torchtext vocabulary size (including the special tokens)."""
    return len(self.vocab)
328 def produce_results(self, n_epoch, model, nb_tokens = 50):
329 file_name = f'result_wiki103_{n_epoch:04d}.txt'
331 with open(file_name, 'w') as outfile:
333 'the cat is hunting a',
334 'paris is the capital',
335 'cars are convenient',
336 'the difference between men and women is',
337 'the object was blue all over and green all over it was',
338 'cherries are red and lemons are',
339 'cherries are sweet and lemons are',
340 'two plus three equals',
343 t_primer = self.tokenizer(primer)
346 for j in range(nb_tokens):
348 input = self.tensorize([ t_primer + t_generated ]).to(self.device)
349 input = F.pad(input, (0, 1)) # Add the next token, the one to predict
350 output = model(input)
351 logits = output[0, -1]
352 if args.synthesis_sampling:
353 dist = torch.distributions.categorical.Categorical(logits = logits)
354 t_next = dist.sample()
356 t_next = logits.argmax()
357 t_generated.append(self.vocab.lookup_token(t_next))
358 if t_generated[-1] == '<non>': break
360 s = ' '.join(t_generated)
362 outfile.write(f'<{primer}> {s}\n')
364 log_string(f'wrote {file_name}')
366 ######################################################################
368 class TaskMNIST(Task):
370 def __init__(self, batch_size, device = torch.device('cpu')):
372 self.batch_size = batch_size
374 def batches(self, split = 'train'):
375 assert split in { 'train', 'test' }
376 data_set = torchvision.datasets.MNIST(
377 root = './data', train = (split == 'train'),
380 data_input = data_set.data.view(-1, 28 * 28).long()
381 if args.data_size >= 0:
382 data_input = data_input[:args.data_size]
383 for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
386 def vocabulary_size(self):
def produce_results(self, n_epoch, model, nb_samples = 64):
    """Autoregressively synthesize nb_samples MNIST images with the model
    and save them as one PNG grid named after the epoch."""
    # Each image is a sequence of 28*28 pixel tokens generated one at a time.
    results = autoregression(model, nb_samples, 28 * 28, device = self.device)
    image_name = f'result_mnist_{n_epoch:04d}.png'
    # Tokens are presumably pixel intensities in [0, 255] — rescale to
    # [0, 1] and invert so digits render dark on a light background.
    torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                 image_name, nrow = 16, pad_value = 0.8)
    log_string(f'wrote {image_name}')
396 ######################################################################
398 log_string(f'device {device}')
400 if args.data == 'wiki103':
401 nb_epochs_default = 10
402 task = TaskWiki103(batch_size = args.batch_size, device = device)
403 elif args.data == 'mnist':
404 nb_epochs_default = 25
405 task = TaskMNIST(batch_size = args.batch_size, device = device)
406 elif args.data == 'picoclvr':
407 nb_epochs_default = 10
408 task = TaskPicoCLVR(batch_size = args.batch_size,
409 height = args.picoclvr_height,
410 width = args.picoclvr_width,
411 nb_colors = args.picoclvr_nb_colors,
414 raise ValueError(f'Unknown dataset {args.data}.')
# The model's output dimensionality must match the task's token vocabulary.
vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')
420 ##############################
423 vocabulary_size = vocabulary_size,
424 dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
425 nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
# Total trainable-tensor element count, logged raw and in rounded millions.
nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
433 ######################################################################
# Select the optimizer named by --optim. An unknown name is a hard error so
# a typo on the command line fails fast instead of silently training with
# some fallback optimizer.
if args.optim == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
else:
    # FIX: the raise must sit in an else branch; as extracted it followed the
    # elif chain unguarded and would have aborted every run unconditionally.
    raise ValueError(f'Unknown optimizer {args.optim}.')
444 ######################################################################
446 nb_epochs_finished = 0
448 if args.no_checkpoint:
449 log_string(f'Not trying to load checkpoint.')
453 checkpoint = torch.load(args.checkpoint_name, map_location = device)
454 nb_epochs_finished = checkpoint['nb_epochs_finished']
455 model.load_state_dict(checkpoint['model_state'])
456 optimizer.load_state_dict(checkpoint['optimizer_state'])
457 log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
459 except FileNotFoundError:
460 log_string('Starting from scratch.')
463 log_string('Error when loading the checkpoint.')
466 ######################################################################
468 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
471 for input in task.batches(split = 'train'):
472 token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
473 token_probas = token_count / token_count.sum()
474 h = -torch.xlogy(token_probas, token_probas).sum()
475 train_set_perplexity = math.exp(h)
476 log_string(f'Train set perplexity {train_set_perplexity}')
478 for k in range(nb_epochs_finished, nb_epochs):
482 nb_train_samples, acc_train_loss = 0, 0.0
484 for input in task.batches(split = 'train'):
485 input = input.to(device)
486 output = model(input)
487 loss = F.cross_entropy(output.transpose(1, 2), input)
488 acc_train_loss += loss.item() * input.size(0)
489 nb_train_samples += input.size(0)
491 optimizer.zero_grad()
495 with torch.autograd.no_grad():
499 nb_test_samples, acc_test_loss = 0, 0.0
501 for input in task.batches(split = 'test'):
502 input = input.to(device)
503 output = model(input)
504 loss = F.cross_entropy(output.transpose(1, 2), input)
505 acc_test_loss += loss.item() * input.size(0)
506 nb_test_samples += input.size(0)
508 train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
509 test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
511 log_string(f'perplexity {k} train {train_perplexity} test {test_perplexity}')
513 task.produce_results(k, model)
516 'nb_epochs_finished': k + 1,
517 'model_state': model.state_dict(),
518 'optimizer_state': optimizer.state_dict()
521 torch.save(checkpoint, args.checkpoint_name)
523 ######################################################################