#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, sys, argparse, time, tqdm, itertools

import torch, torchtext, torchvision

from torch.nn import functional as F

import mygpt

######################################################################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

######################################################################
parser = argparse.ArgumentParser(description = 'My own GPT.')

parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--seed',
                    type = int, default = 0)

parser.add_argument('--nb_epochs',
                    type = int, default = -1)

parser.add_argument('--batch_size',
                    type = int, default = 25)

parser.add_argument('--data',
                    type = str, default = 'wiki103')

parser.add_argument('--data_size',
                    type = int, default = -1)

parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

parser.add_argument('--synthesis_sampling',
                    action='store_true', default = True)

parser.add_argument('--no_checkpoint',
                    action='store_true', default = False)

parser.add_argument('--checkpoint_name',
                    type = str, default = 'checkpoint.pth')
##############################
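# Options specific to the picoclvr task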
parser.add_argument('--picoclvr_nb_colors',
                    type = int, default = 5)

parser.add_argument('--picoclvr_height',
                    type = int, default = 12)

parser.add_argument('--picoclvr_width',
                    type = int, default = 16)

######################################################################

args = parser.parse_args()

log_file = open(args.log_filename, 'w')

torch.manual_seed(args.seed)

######################################################################

def log_string(s):
    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())

    if log_file is not None:
        log_file.write(t + s + '\n')
        log_file.flush()

    print(t + s)
    sys.stdout.flush()

for n in vars(args):
    log_string(f'args.{n} {getattr(args, n)}')

######################################################################
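# Generate nb_samples sequences of nb_tokens_to_generate tokens each,
# batch by batch, optionally continuing a common primer. The sequences
# are extended one position at a time by re-running the model on what
# has been generated so far.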
def autoregression(
        model, batch_size,
        nb_samples, nb_tokens_to_generate, primer = None,
        device = torch.device('cpu')
):
    results = torch.zeros(
        nb_samples, nb_tokens_to_generate,
        dtype = torch.int64, device = device
    )

    if primer is None:
        first = 0
    else:
        first = primer.size(1)
        results = torch.cat((primer, results), 1)

    for input in results.split(batch_size):
        for s in range(first, input.size(1)):
            output = model(input)
            logits = output[:, s]
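            # Pick the next token either by sampling the predicted
            # distribution, or greedily as the most likely one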
            if args.synthesis_sampling:
                dist = torch.distributions.categorical.Categorical(logits = logits)
                t_next = dist.sample()
            else:
                t_next = logits.argmax(1)
            input[:, s] = t_next

    return results

######################################################################
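# A Task provides the batches of the train and test splits, the
# vocabulary size, and a procedure to generate samples with the model
# at the end of an epoch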
class Task:
    def batches(self, split = 'train'):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(self, n_epoch, model):
        pass

######################################################################
import picoclvr

class TaskPicoCLVR(Task):

    # Make a tensor from a list of strings
    def tensorize(self, descr):
        token_descr = [ s.strip().split(' ') for s in descr ]
        l = max([ len(s) for s in token_descr ])
        #token_descr = [ [ '<nul>' ] * (l - len(s)) + s for s in token_descr ]
        token_descr = [ s + [ '<nul>' ] * (l - len(s)) for s in token_descr ]
        id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ]
        return torch.tensor(id_descr, device = self.device)

    def trim(self, x, token = '<nul>'):
        n = self.token2id[token]
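        # Find the first and last columns that are not entirely made of
        # the padding token (x gets one extra '<nul>' column on each
        # side so that the cumulative sum below is well defined), and
        # keep only that central slice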
        i = (1 - (F.pad(x, (1, 1), value = n) == n).min(0).values.long()).cumsum(0)
        a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
        return x[:, a:b]
    def __init__(self, batch_size,
                 height, width, nb_colors = 5,
                 device = torch.device('cpu')):

        def generate_descr(nb):
            return picoclvr.generate(
                nb,
                height = self.height, width = self.width,
                nb_colors = nb_colors
            )

        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        nb = args.data_size if args.data_size > 0 else 250000
        log_string(f'generating {nb} samples (can take some time)')
        self.train_descr = generate_descr((nb * 4) // 5)
        self.test_descr = generate_descr((nb * 1) // 5)

        # Build the tokenizer
        tokens = set()
        for d in [ self.train_descr, self.test_descr ]:
            for s in d:
                for t in s.strip().split(' '): tokens.add(t)
        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)
    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        input = self.train_input if split == 'train' else self.test_input
        for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
            yield self.trim(batch)

    def vocabulary_size(self):
        return len(self.token2id)
    def produce_results(self, n_epoch, model):
        nb_tokens_to_generate = self.height * self.width + 3
        result_descr = [ ]
        nb_per_primer = 8

        for primer_descr in [
                'red above green <sep> green top <sep> blue right of red <img>',
                'there is red <sep> there is yellow <sep> there is blue <img>',
                'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
                'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
        ]:
            results = autoregression(
                model,
                self.batch_size,
                nb_samples = nb_per_primer,
                nb_tokens_to_generate = nb_tokens_to_generate,
                primer = self.tensorize([ primer_descr ]).expand(nb_per_primer, -1),
                device = self.device
            )

            l = [ ' '.join([ self.id2token[t.item()] for t in r ]) for r in results ]
            result_descr += l

        np = picoclvr.nb_properties(
            result_descr,
            height = self.height, width = self.width
        )
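        # For every generated description, count the properties the
        # primer requested and the ones the generated image misses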
        nb_requested_properties, _, nb_missing_properties = zip(*np)

        log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
        img = [
            picoclvr.descr2img(d, height = self.height, width = self.width)
            for d in result_descr
        ]

        img = torch.cat(img, 0)
        image_name = f'result_picoclvr_{n_epoch:04d}.png'
        torchvision.utils.save_image(
            img / 255.,
            image_name, nrow = nb_per_primer, pad_value = 0.8
        )
        log_string(f'wrote {image_name}')

######################################################################
class TaskWiki103(Task):

    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.len_min = len_min
        self.len_max = len_max
        self.min_freq = min_freq
        self.device = device

        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

        def yield_tokens():
            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
                yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            yield_tokens(),
            specials = [ '<unk>', '<nul>' ],
            min_freq = self.min_freq
        )
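        # Any out-of-vocabulary word will be mapped to '<unk>'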
        self.vocab.set_default_index(self.vocab[ '<unk>' ])
    # Makes a tensor from a list of lists of tokens
    def tensorize(self, s):
        a = max(len(x) for x in s)
        return torch.tensor([ self.vocab(x + [ '<nul>' ] * (a - len(x))) for x in s ])
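    # Accumulate the tokenized lines of acceptable length and yield
    # them as padded batches of batch_size sequences, plus a final
    # smaller batch if some lines remain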
    def yield_batches(self, ds):
        s = [ ]
        for l in ds:
            q = self.tokenizer(l)
            if len(q) >= self.len_min and len(q) <= self.len_max:
                s += [ q ]
                if len(s) == self.batch_size:
                    yield self.tensorize(s)
                    s = [ ]

        if len(s) > 0:
            yield self.tensorize(s)
    def batches(self, split = 'train'):
        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

        if args.data_size > 0:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))

    def vocabulary_size(self):
        return len(self.vocab)
    def produce_results(self, n_epoch, model):
        nb_tokens = 50
        file_name = f'result_wiki103_{n_epoch:04d}.txt'

        with open(file_name, 'w') as outfile:
            for primer in [
                    'the cat is hunting a',
                    'paris is the capital',
                    'cars are convenient',
                    'the difference between men and women is',
                    'the object was blue all over and green all over it was',
                    'cherries are red and lemons are',
                    'cherries are sweet and lemons are',
                    'two plus three equals',
            ]:
                t_primer = self.tokenizer(primer)
                t_generated = [ ]

                for j in range(nb_tokens):
                    input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                    input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                    output = model(input)
                    logits = output[0, -1]
                    if args.synthesis_sampling:
                        dist = torch.distributions.categorical.Categorical(logits = logits)
                        t_next = dist.sample()
                    else:
                        t_next = logits.argmax()
                    t_generated.append(self.vocab.lookup_token(t_next))
                    if t_generated[-1] == '<nul>': break

                s = ' '.join(t_generated)

                outfile.write(f'<{primer}> {s}\n')

        log_string(f'wrote {file_name}')
######################################################################

class TaskMNIST(Task):

    def __init__(self, batch_size, device = torch.device('cpu')):
        self.batch_size = batch_size
        self.device = device

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),
            download = True
        )
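        # Every image is flattened into a sequence of 28 * 28 pixels,
        # and every intensity in {0, ..., 255} is a token, hence the
        # vocabulary size of 256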
        data_input = data_set.data.view(-1, 28 * 28).long()
        if args.data_size >= 0:
            data_input = data_input[:args.data_size]
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
            yield batch
    def vocabulary_size(self):
        return 256
    def produce_results(self, n_epoch, model):
        nb_samples = 64
        results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
        image_name = f'result_mnist_{n_epoch:04d}.png'
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')

######################################################################
log_string(f'device {device}')

if args.data == 'wiki103':
    nb_epochs_default = 10
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    nb_epochs_default = 25
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    nb_epochs_default = 10
    task = TaskPicoCLVR(batch_size = args.batch_size,
                        height = args.picoclvr_height,
                        width = args.picoclvr_width,
                        nb_colors = args.picoclvr_nb_colors,
                        device = device)
else:
    raise ValueError(f'Unknown dataset {args.data}.')
vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')

##############################

model = mygpt.MyGPT(
    vocabulary_size = vocabulary_size,
    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
)

model.to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
######################################################################

if args.optim == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
else:
    raise ValueError(f'Unknown optimizer {args.optim}.')
######################################################################

nb_epochs_finished = 0

if args.no_checkpoint:
    log_string('not trying to load checkpoint.')

else:
    try:
        checkpoint = torch.load(args.checkpoint_name, map_location = device)
        nb_epochs_finished = checkpoint['nb_epochs_finished']
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')

    except FileNotFoundError:
        log_string('starting from scratch.')

    except:
        log_string('error when loading the checkpoint.')
        exit(1)

######################################################################
nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
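# Compute the entropy of the empirical (unigram) token distribution of
# the training set; its exp is the perplexity of the best memoryless
# model, a baseline for the perplexities logged during training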
token_count = 0
for input in task.batches(split = 'train'):
    token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
for k in range(nb_epochs_finished, nb_epochs):

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
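        # F.cross_entropy expects the class dimension second, hence the
        # transpose; the model is assumed to shift its input internally
        # so that the output at position t is the prediction of input[t]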
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    with torch.autograd.no_grad():

        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)
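        # The per-sample average loss is capped at 100 nats before
        # exponentiating, so that math.exp cannot overflow early in
        # training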
        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

        log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
        task.produce_results(k, model)
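        # Save a checkpoint at the end of every epoch so that an
        # interrupted run can restart where it stopped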
        checkpoint = {
            'nb_epochs_finished': k + 1,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict()
        }

        torch.save(checkpoint, args.checkpoint_name)

######################################################################