# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, sys, argparse, time, tqdm, itertools

import torch, torchtext, torchvision

from torch.nn import functional as F

# Local modules of this repository (assumed): the GPT model and the
# picoclvr scene generator used below
import mygpt, picoclvr

######################################################################

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

######################################################################

parser = argparse.ArgumentParser(description = 'My own GPT.')

parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--seed',
                    type = int, default = 0)

parser.add_argument('--nb_epochs',
                    type = int, default = -1)

parser.add_argument('--batch_size',
                    type = int, default = 25)

parser.add_argument('--data',
                    type = str, default = 'wiki103')

parser.add_argument('--data_size',
                    type = int, default = -1)

parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-3)

parser.add_argument('--learning_rate_end',
                    type = float, default = 1e-6)

parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

parser.add_argument('--deterministic_synthesis',
                    action = 'store_true', default = False)

parser.add_argument('--no_checkpoint',
                    action = 'store_true', default = False)

parser.add_argument('--checkpoint_name',
                    type = str, default = 'checkpoint.pth')

##############################

parser.add_argument('--picoclvr_nb_colors',
                    type = int, default = 5)

parser.add_argument('--picoclvr_height',
                    type = int, default = 12)

parser.add_argument('--picoclvr_width',
                    type = int, default = 16)

######################################################################

args = parser.parse_args()

log_file = open(args.log_filename, 'w')

torch.manual_seed(args.seed)

######################################################################

def log_string(s):
    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())

    if log_file is not None:
        log_file.write(t + s + '\n')
        log_file.flush()

    print(t + s)
    sys.stdout.flush()

for n in vars(args):
    log_string(f'args.{n} {getattr(args, n)}')

######################################################################
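
# Generate nb_samples sequences of nb_tokens_to_generate tokens, batch
# by batch, optionally starting from a primer. Each token is sampled
# from the model's predicted distribution, or taken as its argmax when
# --deterministic_synthesis is set.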
def autoregression(
        model, batch_size,
        nb_samples, nb_tokens_to_generate, primer = None,
        device = torch.device('cpu')
):
    results = torch.zeros(
        nb_samples, nb_tokens_to_generate,
        dtype = torch.int64, device = device
    )

    if primer is None:
        first = 0
    else:
        first = primer.size(1)
        results = torch.cat((primer, results), 1)

    # split returns views, so writing into input fills results
    for input in results.split(batch_size):
        for s in range(first, input.size(1)):
            output = model(input)
            logits = output[:, s]
            if args.deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits = logits)
                t_next = dist.sample()
            input[:, s] = t_next

    return results

######################################################################
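
# A Task provides the train and test batches, the vocabulary size, and
# a way to produce and log sample results at the end of an epoch.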
class Task:
    def batches(self, split = 'train'):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(self, n_epoch, model):
        pass

######################################################################

class TaskPicoCLVR(Task):

    # Make a tensor from a list of strings
    def tensorize(self, descr):
        token_descr = [ s.strip().split(' ') for s in descr ]
        l = max([ len(s) for s in token_descr ])
        # Alternative, padding on the left instead of the right:
        # token_descr = [ [ '<nul>' ] * (l - len(s)) + s for s in token_descr ]
        token_descr = [ s + [ '<nul>' ] * (l - len(s)) for s in token_descr ]
        id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ]
        return torch.tensor(id_descr, device = self.device)
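
    # Drop the leading and trailing columns that contain only the given
    # token across the whole batch, keeping the smallest window that
    # holds every real token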
    def trim(self, x, token = '<nul>'):
        n = self.token2id[token]
        i = (1 - (F.pad(x, (1, 1), value = n) == n).min(0).values.long()).cumsum(0)
        a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
        return x[:, a:b]

    def __init__(self, batch_size,
                 height, width, nb_colors = 5,
                 device = torch.device('cpu')):

        def generate_descr(nb):
            return picoclvr.generate(
                nb,
                height = self.height, width = self.width,
                nb_colors = nb_colors
            )

        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device

        nb = args.data_size if args.data_size > 0 else 250000

        log_string(f'generating {nb} samples (can take some time)')
        self.train_descr = generate_descr((nb * 4) // 5)
        self.test_descr = generate_descr((nb * 1) // 5)

        # Build the tokenizer, with '<nul>' included for padding
        tokens = { '<nul>' }
        for d in [ self.train_descr, self.test_descr ]:
            for s in d:
                for t in s.strip().split(' '): tokens.add(t)
        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        input = self.train_input if split == 'train' else self.test_input
        for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
            yield self.trim(batch)

    def vocabulary_size(self):
        return len(self.token2id)
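
    # Generate images from a few hand-written primer descriptions,
    # measure how many of the requested properties are missing in the
    # results, and save the generated images to a png file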
    def produce_results(self, n_epoch, model):
        nb_tokens_to_generate = self.height * self.width + 3
        result_descr = [ ]
        nb_per_primer = 8 # samples per primer (assumed value)

        for primer_descr in [
            'red above green <sep> green top <sep> blue right of red <img>',
            'there is red <sep> there is yellow <sep> there is blue <img>',
            'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
            'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
        ]:
            results = autoregression(
                model, self.batch_size,
                nb_samples = nb_per_primer,
                nb_tokens_to_generate = nb_tokens_to_generate,
                primer = self.tensorize([ primer_descr ]).expand(nb_per_primer, -1),
                device = self.device
            )

            l = [ ' '.join([ self.id2token[t.item()] for t in r ]) for r in results ]
            result_descr += l

        np = picoclvr.nb_properties(
            result_descr,
            height = self.height, width = self.width
        )

        nb_requested_properties, _, nb_missing_properties = zip(*np)

        log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')

        img = [
            picoclvr.descr2img(d, height = self.height, width = self.width)
            for d in result_descr
        ]

        img = torch.cat(img, 0)
        image_name = f'result_picoclvr_{n_epoch:04d}.png'
        torchvision.utils.save_image(
            img / 255., # assuming descr2img returns pixel values in [0, 255]
            image_name, nrow = nb_per_primer, pad_value = 0.8
        )
        log_string(f'wrote {image_name}')

######################################################################

class TaskWiki103(Task):

    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.len_min = len_min
        self.len_max = len_max
        self.min_freq = min_freq
        self.device = device

        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

        # Mostly for debugging
        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

        def yield_tokens():
            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
                yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            yield_tokens(),
            specials = [ '<unk>', '<nul>' ],
            min_freq = self.min_freq
        )

        self.vocab.set_default_index(self.vocab[ '<unk>' ])

    # Make a tensor from a list of lists of tokens
    def tensorize(self, s):
        a = max(len(x) for x in s)
        return torch.tensor([ self.vocab(x + [ '<nul>' ] * (a - len(x))) for x in s ])
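
    # Group the lines whose token length falls in [len_min, len_max]
    # into batches of batch_size sequences, right-padded with '<nul>'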
    def yield_batches(self, ds):
        s = [ ]
        for l in ds:
            q = self.tokenizer(l)
            if len(q) >= self.len_min and len(q) <= self.len_max:
                s += [ q ]
                if len(s) == self.batch_size:
                    yield self.tensorize(s)
                    s = [ ]

        # Flush the last, possibly incomplete, batch
        if len(s) > 0:
            yield self.tensorize(s)

    def batches(self, split = 'train'):
        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

        # Mostly for debugging
        if args.data_size > 0:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))

    def vocabulary_size(self):
        return len(self.vocab)
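
    # Generate a continuation for each primer sentence, one token at a
    # time, and write the results to a text file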
    def produce_results(self, n_epoch, model):
        nb_tokens = 50 # tokens to generate per primer (assumed value)
        file_name = f'result_wiki103_{n_epoch:04d}.txt'

        with open(file_name, 'w') as outfile:
             for primer in [
                 'the cat is hunting a',
                 'paris is the capital',
                 'cars are convenient',
                 'the difference between men and women is',
                 'the object was blue all over and green all over it was',
                 'cherries are red and lemons are',
                 'cherries are sweet and lemons are',
                 'two plus three equals',
             ]:
                 t_primer = self.tokenizer(primer)
                 t_generated = [ ]

                 for j in range(nb_tokens):
                     input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                     input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                     output = model(input)
                     logits = output[0, -1]
                     if args.deterministic_synthesis:
                         t_next = logits.argmax()
                     else:
                         dist = torch.distributions.categorical.Categorical(logits = logits)
                         t_next = dist.sample()
                     t_generated.append(self.vocab.lookup_token(t_next))
                     if t_generated[-1] == '<nul>': break

                 s = ' '.join(t_generated)

                 outfile.write(f'<{primer}> {s}\n')

        log_string(f'wrote {file_name}')

######################################################################
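
# Treat each MNIST image as a sequence of 28 * 28 = 784 pixel tokens
# with values in {0, ..., 255}, so the same autoregressive model and
# generation procedure apply unchanged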
class TaskMNIST(Task):

    def __init__(self, batch_size, device = torch.device('cpu')):
        self.device = device
        self.batch_size = batch_size

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),
            download = True
        )
        data_input = data_set.data.view(-1, 28 * 28).long()
        if args.data_size >= 0:
            data_input = data_input[:args.data_size]
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
            yield batch

    def vocabulary_size(self):
        return 256 # one token per possible pixel value

    def produce_results(self, n_epoch, model):
        nb_samples = 64 # number of images to generate (assumed value)
        results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
        image_name = f'result_mnist_{n_epoch:04d}.png'
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')

######################################################################

log_string(f'device {device}')

if args.data == 'wiki103':
    nb_epochs_default = 10
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    nb_epochs_default = 25
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    nb_epochs_default = 10
    task = TaskPicoCLVR(batch_size = args.batch_size,
                        height = args.picoclvr_height,
                        width = args.picoclvr_width,
                        nb_colors = args.picoclvr_nb_colors,
                        device = device)
else:
    raise ValueError(f'Unknown dataset {args.data}.')

vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')

##############################

# Instantiate the GPT model (assumed to be mygpt.MyGPT, taking the
# keyword arguments below)
model = mygpt.MyGPT(
    vocabulary_size = vocabulary_size,
    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
)

model.to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')

######################################################################
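
# Checkpoints store the number of finished epochs, the model weights,
# and the RNG states, so an interrupted run can resume and remain
# reproducible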
nb_epochs_finished = 0

if args.no_checkpoint:
    log_string('not trying to load checkpoint.')

else:
    try:
        checkpoint = torch.load(args.checkpoint_name)
        nb_epochs_finished = checkpoint['nb_epochs_finished']
        model.load_state_dict(checkpoint['model_state'])
        torch.set_rng_state(checkpoint['rng_state'])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])
        log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')

    except FileNotFoundError:
        log_string('starting from scratch.')

    except:
        log_string('error when loading the checkpoint.')
        exit(1)

######################################################################

nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
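
# The perplexity of the train set's unigram distribution, i.e. the exp
# of its token entropy, gives a reference value for the model's
# prediction perplexity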
token_count = 0
for input in task.batches(split = 'train'):
    token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)

for n_epoch in range(nb_epochs_finished, nb_epochs):
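
    # Unless --learning_rate_end is negative, interpolate the learning
    # rate geometrically between learning_rate and learning_rate_end
    # over the whole run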
    if args.learning_rate_end < 0:
        lr = args.learning_rate
    else:
        u = n_epoch / (nb_epochs - 1)
        lr = math.exp((1 - u) * math.log(args.learning_rate) +
                      u * math.log(args.learning_rate_end))
    log_string(f'learning_rate {lr}')

    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr = lr)
    elif args.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr = lr)
    elif args.optim == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
    else:
        raise ValueError(f'Unknown optimizer {args.optim}.')

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        # The logits at position s predict token s from the tokens
        # before it (cf. autoregression), so output and targets align;
        # cross_entropy wants the class dim second, hence the transpose
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
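
    # Evaluate on the test set, with gradients disabled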
    with torch.autograd.no_grad():

        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        # The loss is capped at 100 to avoid overflowing math.exp
        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

        log_string(f'perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')

        task.produce_results(n_epoch, model)
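
    # Save a checkpoint at the end of every epoch so the run can be
    # resumed where it stopped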
    checkpoint = {
        'nb_epochs_finished': n_epoch + 1,
        'model_state': model.state_dict(),
        'rng_state': torch.get_rng_state(),
    }

    if torch.cuda.is_available():
        checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state()

    torch.save(checkpoint, args.checkpoint_name)

######################################################################