# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, sys, argparse, time, tqdm, itertools

import torch, torchtext, torchvision

from torch.nn import functional as F

import mygpt

######################################################################

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

######################################################################

parser = argparse.ArgumentParser(description = 'My own GPT.')

parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--seed',
                    type = int, default = 0)

parser.add_argument('--nb_epochs',
                    type = int, default = -1)

parser.add_argument('--batch_size',
                    type = int, default = 25)

parser.add_argument('--data',
                    type = str, default = 'wiki103')

parser.add_argument('--data_size',
                    type = int, default = -1)

parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

# action = 'store_true' combined with default = True could never be
# disabled; BooleanOptionalAction (Python >= 3.9) also generates a
# --no-synthesis_sampling flag
parser.add_argument('--synthesis_sampling',
                    action = argparse.BooleanOptionalAction, default = True)

parser.add_argument('--no_checkpoint',
                    action = 'store_true', default = False)

parser.add_argument('--checkpoint_name',
                    type = str, default = 'checkpoint.pth')

##############################

parser.add_argument('--picoclvr_nb_colors',
                    type = int, default = 5)

parser.add_argument('--picoclvr_height',
                    type = int, default = 12)

parser.add_argument('--picoclvr_width',
                    type = int, default = 16)

######################################################################

args = parser.parse_args()

log_file = open(args.log_filename, 'w')

torch.manual_seed(args.seed)

######################################################################

def log_string(s):
    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())

    if log_file is not None:
        log_file.write(t + s + '\n')
        log_file.flush()

    print(t + s)
    sys.stdout.flush()

for n in vars(args):
    log_string(f'args.{n} {getattr(args, n)}')

######################################################################
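
# Generate nb_samples sequences of nb_tokens_to_generate tokens with
# the model, optionally starting every sequence with a common primer.
# Tokens are produced one position at a time, left to right, either by
# sampling the model's output distribution or greedily with argmax.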
def autoregression(model, batch_size,
                   nb_samples, nb_tokens_to_generate, primer = None,
                   device = torch.device('cpu')):
    results = torch.zeros(
        nb_samples, nb_tokens_to_generate,
        dtype = torch.int64, device = device
    )

    if primer is None:
        first = 0
    else:
        first = primer.size(1)
        results = torch.cat((primer, results), 1)

    for input in results.split(batch_size):
        for s in range(first, input.size(1)):
            output = model(input)
            logits = output[:, s]
            if args.synthesis_sampling:
                dist = torch.distributions.categorical.Categorical(logits = logits)
                t_next = dist.sample()
            else:
                t_next = logits.argmax(1)
            # input is a view into results, so this fills results in place
            input[:, s] = t_next

    return results

######################################################################

class Task:
    def batches(self, split = 'train'):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(self, n_epoch, model):
        pass

######################################################################

import picoclvr

class TaskPicoCLVR(Task):

    # Make a tensor from a list of strings
    def tensorize(self, descr):
        token_descr = [ s.strip().split(' ') for s in descr ]
        l = max([ len(s) for s in token_descr ])
        #token_descr = [ [ '<nul>' ] * (l - len(s)) + s for s in token_descr ]
        token_descr = [ s + [ '<nul>' ] * (l - len(s)) for s in token_descr ]
        id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ]
        return torch.tensor(id_descr, device = self.device)
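
    # Remove the leading and trailing columns of x whose every row is
    # the '<nul>' token, keeping the smallest window that still covers
    # all the non-padding tokens of the batch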
    def trim(self, x, token = '<nul>'):
        n = self.token2id[token]
        i = (1 - (F.pad(x, (1, 1), value = n) == n).min(0).values.long()).cumsum(0)
        a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
        return x[:, a:b]

    def __init__(self, batch_size,
                 height, width, nb_colors = 5,
                 device = torch.device('cpu')):

        def generate_descr(nb):
            return picoclvr.generate(
                nb,
                height = self.height, width = self.width,
                nb_colors = nb_colors
            )

        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        nb = args.data_size if args.data_size > 0 else 250000

        self.train_descr = generate_descr((nb * 4) // 5)
        self.test_descr = generate_descr((nb * 1) // 5)

        # Build the tokenizer
        tokens = set()
        for d in [ self.train_descr, self.test_descr ]:
            for s in d:
                for t in s.strip().split(' '): tokens.add(t)
        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)
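
    # Yield the train or test sequences as mini-batches, with the
    # shared '<nul>' padding of each batch trimmed away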
    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        input = self.train_input if split == 'train' else self.test_input
        for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
            yield self.trim(batch)

    def vocabulary_size(self):
        return len(self.token2id)

    def produce_results(self, n_epoch, model):
        nb_tokens_to_generate = self.height * self.width + 3
        result_descr = [ ]
        nb_per_primer = 8

        for primer_descr in [
                'red above green <sep> green top <sep> blue right of red <img>',
                'there is red <sep> there is yellow <sep> there is blue <img>',
                'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
                'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
        ]:

            results = autoregression(
                model, self.batch_size,
                nb_samples = nb_per_primer,
                nb_tokens_to_generate = nb_tokens_to_generate,
                primer = self.tensorize([ primer_descr ]).expand(nb_per_primer, -1),
                device = self.device
            )

            l = [ ' '.join([ self.id2token[t.item()] for t in r ]) for r in results ]
            result_descr += l

        np = picoclvr.nb_properties(
            result_descr,
            height = self.height, width = self.width
        )

        nb_requested_properties, _, nb_missing_properties = zip(*np)

        log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')

        img = [
            picoclvr.descr2img(d, height = self.height, width = self.width)
            for d in result_descr
        ]

        img = torch.cat(img, 0)
        image_name = f'result_picoclvr_{n_epoch:04d}.png'
        torchvision.utils.save_image(
            img / 255.,
            image_name, nrow = nb_per_primer, pad_value = 0.8
        )
        log_string(f'wrote {image_name}')

######################################################################

class TaskWiki103(Task):

    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.len_min = len_min
        self.len_max = len_max
        self.min_freq = min_freq
        self.device = device

        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

        # Mostly for debugging, cut the data set down to size
        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

        def yield_tokens():
            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
                yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            yield_tokens(),
            specials = [ '<unk>', '<nul>' ],
            min_freq = self.min_freq
        )
        self.vocab.set_default_index(self.vocab[ '<unk>' ])

    # Makes a tensor from a list of lists of tokens
    def tensorize(self, s):
        a = max(len(x) for x in s)
        return torch.tensor([ self.vocab(x + [ '<nul>' ] * (a - len(x))) for x in s ])
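
    # Accumulate the sentences whose token length is in
    # [len_min, len_max] and yield them as tensorized mini-batches of
    # batch_size sentences, plus a final smaller batch for the leftovers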
    def yield_batches(self, ds):
        s = [ ]
        for l in ds:
            q = self.tokenizer(l)
            if len(q) >= self.len_min and len(q) <= self.len_max:
                s += [ q ]
                if len(s) == self.batch_size:
                    yield self.tensorize(s)
                    s = [ ]

        if len(s) > 0:
            yield self.tensorize(s)

    def batches(self, split = 'train'):
        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

        if args.data_size > 0:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))

    def vocabulary_size(self):
        return len(self.vocab)

    def produce_results(self, n_epoch, model):
        nb_tokens = 50
        file_name = f'result_wiki103_{n_epoch:04d}.txt'

        with open(file_name, 'w') as outfile:
            for primer in [
                    'the cat is hunting a',
                    'paris is the capital',
                    'cars are convenient',
                    'the difference between men and women is',
                    'the object was blue all over and green all over it was',
                    'cherries are red and lemons are',
                    'cherries are sweet and lemons are',
                    'two plus three equals',
            ]:
                t_primer = self.tokenizer(primer)
                t_generated = [ ]

                for j in range(nb_tokens):
                    input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                    input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                    output = model(input)
                    logits = output[0, -1]
                    if args.synthesis_sampling:
                        dist = torch.distributions.categorical.Categorical(logits = logits)
                        t_next = dist.sample()
                    else:
                        t_next = logits.argmax()
                    t_generated.append(self.vocab.lookup_token(t_next))
                    if t_generated[-1] == '<nul>': break

                s = ' '.join(t_generated)

                outfile.write(f'<{primer}> {s}\n')

        log_string(f'wrote {file_name}')

######################################################################
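
# Treats each MNIST digit as a sequence of 28 * 28 = 784 tokens, one
# per pixel, with the 256 gray levels as the vocabulary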
class TaskMNIST(Task):

    def __init__(self, batch_size, device = torch.device('cpu')):
        self.device = device
        self.batch_size = batch_size

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),
            download = True
        )
        data_input = data_set.data.view(-1, 28 * 28).long()
        if args.data_size >= 0:
            data_input = data_input[:args.data_size]
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
            yield batch

    def vocabulary_size(self):
        return 256

    def produce_results(self, n_epoch, model):
        nb_samples = 64
        results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
        image_name = f'result_mnist_{n_epoch:04d}.png'
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')

######################################################################

log_string(f'device {device}')

if args.data == 'wiki103':
    nb_epochs_default = 10
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    nb_epochs_default = 25
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    nb_epochs_default = 10
    task = TaskPicoCLVR(batch_size = args.batch_size,
                        height = args.picoclvr_height,
                        width = args.picoclvr_width,
                        nb_colors = args.picoclvr_nb_colors,
                        device = device)
else:
    raise ValueError(f'Unknown dataset {args.data}.')

vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')

##############################

model = mygpt.MyGPT(
    vocabulary_size = vocabulary_size,
    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
)

model.to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')

######################################################################

if args.optim == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
else:
    raise ValueError(f'Unknown optimizer {args.optim}.')

######################################################################

nb_epochs_finished = 0

if args.no_checkpoint:
    log_string('not trying to load checkpoint.')

else:
    try:
        checkpoint = torch.load(args.checkpoint_name, map_location = device)
        nb_epochs_finished = checkpoint['nb_epochs_finished']
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')

    except FileNotFoundError:
        log_string('starting from scratch.')

    except:
        log_string('error when loading the checkpoint.')
        exit(1)

######################################################################
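
# Estimate the intrinsic perplexity of the training set: the entropy
# of the empirical token distribution is the average loss of an ideal
# unigram model, and its exponential the corresponding perplexity,
# which serves as a reference for the trained model's perplexities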
nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default

token_count = 0
for input in task.batches(split = 'train'):
    token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)

for k in range(nb_epochs_finished, nb_epochs):

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.autograd.no_grad():

        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

        log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')

        task.produce_results(k, model)

        checkpoint = {
            'nb_epochs_finished': k + 1,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict()
        }

        torch.save(checkpoint, args.checkpoint_name)

######################################################################