3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, argparse, time, tqdm, itertools
10 import torch, torchtext, torchvision
12 from torch.nn import functional as F
16 ######################################################################
# Compute device: prefer CUDA when available, fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20 ######################################################################
# Command-line interface. Defaults reproduce the original behavior.
parser = argparse.ArgumentParser(description = 'My own GPT.')

parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--seed',
                    type = int, default = 0)

# -1 means "use the per-dataset default chosen below".
parser.add_argument('--nb_epochs',
                    type = int, default = -1)

parser.add_argument('--batch_size',
                    type = int, default = 25)

parser.add_argument('--data',
                    type = str, default = 'wiki103')

# -1 means "use the full dataset".
parser.add_argument('--data_size',
                    type = int, default = -1)

parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

# Transformer geometry.
parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

# NOTE(review): action='store_true' combined with default=True means this
# flag can never be switched off from the command line; kept as-is to
# preserve behavior, but it should probably be a store_false twin option.
parser.add_argument('--synthesis_sampling',
                    action='store_true', default = True)

parser.add_argument('--no_checkpoint',
                    action='store_true', default = False)

parser.add_argument('--checkpoint_name',
                    type = str, default = 'checkpoint.pth')

##############################

# PicoCLVR-specific options.
parser.add_argument('--picoclvr_nb_colors',
                    type = int, default = 5)

parser.add_argument('--picoclvr_height',
                    type = int, default = 12)

parser.add_argument('--picoclvr_width',
                    type = int, default = 16)
######################################################################

args = parser.parse_args()

log_file = open(args.log_filename, 'w')

torch.manual_seed(args.seed)

######################################################################

def log_string(s):
    """Write s to the log file (if any) and to stdout, with a timestamp prefix."""
    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
    if log_file is not None:
        log_file.write(t + s + '\n')
        # NOTE(review): flush restored from missing lines — confirm against original.
        log_file.flush()
    print(t + s)
    sys.stdout.flush()

# Record the full configuration in the log, one line per argument.
for n in vars(args):
    log_string(f'args.{n} {getattr(args, n)}')

######################################################################
def autoregression(model, batch_size,
                   nb_samples, nb_tokens_to_generate, primer = None,
                   device = torch.device('cpu')):
    """Generate nb_samples token sequences autoregressively with model.

    If primer is given (a (nb_samples, n) int64 tensor), it is kept as the
    prefix of every sequence and generation starts after it. Sampling follows
    args.synthesis_sampling: sample from the softmax, or take the argmax.
    Returns a (nb_samples, first + nb_tokens_to_generate) int64 tensor.
    """
    results = torch.zeros(
        nb_samples, nb_tokens_to_generate,
        dtype = torch.int64, device = device
    )

    # NOTE(review): primer handling restored from missing lines — confirm.
    if primer is None:
        first = 0
    else:
        first = primer.size(1)
        results = torch.cat((primer, results), 1)

    for input in results.split(batch_size):
        for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
            output = model(input)
            logits = output[:, s]
            if args.synthesis_sampling:
                dist = torch.distributions.categorical.Categorical(logits = logits)
                t_next = dist.sample()
            else:
                t_next = logits.argmax(1)
            # split() yields views of `results`, so writing into `input`
            # updates `results` in place.
            input[:, s] = t_next

    return results
142 ######################################################################
class Task:
    """Abstract interface every dataset task implements.

    NOTE(review): class header and `pass` bodies restored from missing
    lines — confirm against the original file.
    """

    def batches(self, split = 'train'):
        # Yield batches of int64 token tensors for the given split.
        pass

    def vocabulary_size(self):
        # Number of distinct token values the model must handle.
        pass

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        # Generate and persist sample outputs for the given epoch.
        pass
154 ######################################################################
class TaskPicoCLVR(Task):
    """Task over synthetic picture descriptions produced by picoclvr."""

    def descr2tensor(self, descr):
        # descr is a list of token lists, all padded to the same length.
        t = [ [ self.token2id[u] for u in s ] for s in descr ]
        return torch.tensor(t, device = self.device)

    def __init__(self, batch_size,
                 height, width, nb_colors = 5,
                 device = torch.device('cpu')):

        def generate_descr(nb):
            # Produce nb descriptions, tokenize on spaces, and right-pad
            # them all with '<unk>' to a common length.
            descr = picoclvr.generate(
                nb,
                height = self.height, width = self.width,
                nb_colors = nb_colors
            )
            descr = [ s.strip().split(' ') for s in descr ]
            l = max([ len(s) for s in descr ])
            descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
            return descr

        # NOTE(review): attribute assignments restored from missing lines.
        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        nb = args.data_size if args.data_size > 0 else 250000

        self.train_descr = generate_descr((nb * 4) // 5)
        self.test_descr = generate_descr((nb * 1) // 5)

        # Build the tokenizer
        tokens = set()
        for d in [ self.train_descr, self.test_descr ]:
            for s in d:
                for t in s: tokens.add(t)
        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

        # Tokenize the train and test sets.
        # Fix: descr2tensor is a method, not a free name — it must be
        # reached through self or this raises NameError.
        self.train_input = self.descr2tensor(self.train_descr)
        self.test_input = self.descr2tensor(self.test_descr)

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        if split == 'train':
            for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
                yield batch
        else:
            for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
                yield batch

    def vocabulary_size(self):
        return len(self.token2id)

    def generate(self, descr_primer, model, nb_tokens):
        # NOTE(review): descr2tensor expects a list of token lists, while
        # descr_primer arrives as a single string from produce_results —
        # confirm the original's primer handling.
        results = autoregression(
            model, self.batch_size,
            1, nb_tokens, primer = self.descr2tensor(descr_primer),
            device = self.device
        )
        return ' '.join([ self.id2token[t.item()] for t in results.flatten() ])

    def produce_results(self, n_epoch, model, nb_tokens = None):
        if nb_tokens is None:
            # One token per grid cell plus a small fixed overhead.
            nb_tokens = self.height * self.width + 3

        result_descr = [ ]
        # NOTE(review): value restored from missing lines — confirm.
        nb_per_primer = 8

        for descr_primer in [
            'red above green <sep> green top <sep> blue right of red <img>',
            'there is red <sep> there is yellow <sep> there is blue <img>',
            'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
            'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
        ]:
            for k in range(nb_per_primer):
                result_descr.append(self.generate(descr_primer, model, nb_tokens))

        img = [ picoclvr.descr2img(d, height = self.height, width = self.width)
                for d in result_descr ]
        img = torch.cat(img, 0)
        image_name = f'result_picoclvr_{n_epoch:04d}.png'
        torchvision.utils.save_image(
            # NOTE(review): normalization restored from missing lines — confirm.
            img / 255.,
            image_name, nrow = nb_per_primer, pad_value = 0.8
        )
        log_string(f'wrote {image_name}')

        # Count requested vs missing properties in the generated descriptions
        # (renamed from `np` to avoid shadowing the conventional numpy alias).
        prop_counts = picoclvr.nb_properties(
            result_descr,
            height = self.height, width = self.width
        )

        nb_requested_properties, _, nb_missing_properties = zip(*prop_counts)

        log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
258 ######################################################################
class TaskWiki103(Task):
    """Language-modeling task over the WikiText-103 corpus (via torchtext)."""

    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.len_min = len_min
        self.len_max = len_max
        self.min_freq = min_freq
        # NOTE(review): restored from missing lines (self.device is read in
        # produce_results) — confirm against original.
        self.device = device

        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

        # Optionally truncate the corpus for quick debugging runs.
        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

        # NOTE(review): generator header restored from missing lines.
        def yield_tokens():
            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
                yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            yield_tokens(),
            specials = [ '<unk>', '<non>' ],
            min_freq = self.min_freq
        )

        # Unknown words map to '<unk>'.
        self.vocab.set_default_index(self.vocab[ '<unk>' ])

    def tensorize(self, s):
        """Pad the token lists in s with '<non>' to equal length and map to ids."""
        a = max(len(x) for x in s)
        return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])

    def yield_batches(self, ds):
        # Accumulate tokenized lines whose length is within bounds and emit
        # them in batches of self.batch_size (plus a final partial batch).
        s = [ ]
        for l in ds:
            q = self.tokenizer(l)
            if self.len_min <= len(q) <= self.len_max:
                s.append(q)
                if len(s) == self.batch_size:
                    yield self.tensorize(s)
                    s = [ ]
        if len(s) > 0:
            yield self.tensorize(s)

    def batches(self, split = 'train'):
        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

        # Optionally truncate the corpus for quick debugging runs.
        if args.data_size > 0:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))

    def vocabulary_size(self):
        return len(self.vocab)

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        """Complete a few fixed English primers and write them to a text file."""
        file_name = f'result_wiki103_{n_epoch:04d}.txt'

        with open(file_name, 'w') as outfile:
            for primer in [
                'the cat is hunting a',
                'paris is the capital',
                'cars are convenient',
                'the difference between men and women is',
                'the object was blue all over and green all over it was',
                'cherries are red and lemons are',
                'cherries are sweet and lemons are',
                'two plus three equals',
            ]:
                t_primer = self.tokenizer(primer)
                t_generated = [ ]

                for j in range(nb_tokens):
                    input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                    input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                    output = model(input)
                    logits = output[0, -1]
                    if args.synthesis_sampling:
                        dist = torch.distributions.categorical.Categorical(logits = logits)
                        t_next = dist.sample()
                    else:
                        t_next = logits.argmax()
                    t_generated.append(self.vocab.lookup_token(t_next))
                    # '<non>' is the padding token: stop at the first one.
                    if t_generated[-1] == '<non>': break

                s = ' '.join(t_generated)

                outfile.write(f'<{primer}> {s}\n')

        log_string(f'wrote {file_name}')
357 ######################################################################
class TaskMNIST(Task):
    """Pixel-by-pixel autoregressive modeling of 28x28 MNIST digits."""

    def __init__(self, batch_size, device = torch.device('cpu')):
        # NOTE(review): restored from missing line (self.device is read in
        # produce_results) — confirm against original.
        self.device = device
        self.batch_size = batch_size

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),
            # NOTE(review): restored from missing line — confirm.
            download = True
        )
        # Flatten each image to a sequence of 784 pixel intensities (0..255).
        data_input = data_set.data.view(-1, 28 * 28).long()
        if args.data_size >= 0:
            data_input = data_input[:args.data_size]
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
            yield batch

    def vocabulary_size(self):
        # One token per 8-bit pixel intensity.
        return 256

    def produce_results(self, n_epoch, model, nb_samples = 64):
        results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
        image_name = f'result_mnist_{n_epoch:04d}.png'
        # Invert so digits render dark-on-light.
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')
387 ######################################################################
log_string(f'device {device}')

# Instantiate the task selected on the command line; each dataset carries
# its own sensible default number of epochs.
if args.data == 'wiki103':
    nb_epochs_default = 10
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    nb_epochs_default = 25
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    nb_epochs_default = 10
    task = TaskPicoCLVR(batch_size = args.batch_size,
                        height = args.picoclvr_height,
                        width = args.picoclvr_width,
                        nb_colors = args.picoclvr_nb_colors,
                        device = device)
else:
    raise ValueError(f'Unknown dataset {args.data}.')

vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')

##############################

# NOTE(review): the constructor line was lost in extraction; mygpt.MyGPT
# inferred from the surrounding keyword arguments — confirm against original.
model = mygpt.MyGPT(
    vocabulary_size = vocabulary_size,
    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
)

model.to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
424 ######################################################################
# Build the optimizer named on the command line; reject unknown names early.
if args.optim == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
else:
    raise ValueError(f'Unknown optimizer {args.optim}.')
435 ######################################################################
nb_epochs_finished = 0

# Resume from a checkpoint unless explicitly disabled. A missing file means
# a fresh start; any other failure aborts rather than silently retraining.
if args.no_checkpoint:
    log_string(f'not trying to load checkpoint.')

else:
    try:
        checkpoint = torch.load(args.checkpoint_name, map_location = device)
        nb_epochs_finished = checkpoint['nb_epochs_finished']
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')

    except FileNotFoundError:
        log_string('starting from scratch.')

    # NOTE(review): bare except kept to preserve behavior; a narrower
    # `except Exception` would avoid swallowing KeyboardInterrupt.
    except:
        log_string('error when loading the checkpoint.')
        exit(1)
457 ######################################################################
nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default

# Empirical token distribution of the train set; its entropy gives the
# perplexity a perfect unigram model would reach — a reference point.
# NOTE(review): `token_count = 0` restored from a missing line — confirm.
token_count = 0
for input in task.batches(split = 'train'):
    token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
#log_string(f'train set perplexity {train_set_perplexity}')
for k in range(nb_epochs_finished, nb_epochs):

    # NOTE(review): model.train()/model.eval() calls restored from missing
    # lines — confirm against original.
    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        # NOTE(review): target is the unshifted input — this assumes the
        # model shifts internally so output[:, s] predicts token s; confirm.
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.autograd.no_grad():

        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        # Cap the mean loss at 100 nats so math.exp cannot overflow when
        # the model diverges.
        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

        log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')

        task.produce_results(k, model)

    # Persist full training state so a later run can resume at epoch k + 1.
    checkpoint = {
        'nb_epochs_finished': k + 1,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict()
    }

    torch.save(checkpoint, args.checkpoint_name)
514 ######################################################################