# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, sys, argparse, time, tqdm, itertools

import torch, torchtext, torchvision

from torch.nn import functional as F

import mygpt
######################################################################

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

######################################################################
parser = argparse.ArgumentParser(description = 'My own GPT.')

parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--seed',
                    type = int, default = 0)

parser.add_argument('--nb_epochs',
                    type = int, default = -1)

parser.add_argument('--batch_size',
                    type = int, default = 25)

parser.add_argument('--data',
                    type = str, default = 'wiki103')

parser.add_argument('--data_size',
                    type = int, default = -1)

parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)
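# Note: declared with action = 'store_true' but default = True, so this
# flag is effectively always set; the generation code below therefore
# always samples from the predicted distribution rather than taking the
# argmax, unless the default is edited here.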
parser.add_argument('--synthesis_sampling',
                    action = 'store_true', default = True)

parser.add_argument('--no_checkpoint',
                    action = 'store_true', default = False)

parser.add_argument('--checkpoint_name',
                    type = str, default = 'checkpoint.pth')
##############################

parser.add_argument('--picoclvr_nb_colors',
                    type = int, default = 5)

parser.add_argument('--picoclvr_height',
                    type = int, default = 12)

parser.add_argument('--picoclvr_width',
                    type = int, default = 16)
######################################################################

args = parser.parse_args()

log_file = open(args.log_filename, 'w')

torch.manual_seed(args.seed)

######################################################################
def log_string(s):
    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
    if log_file is not None:
        log_file.write(t + s + '\n')
        log_file.flush()
    print(t + s)

for n in vars(args):
    log_string(f'args.{n} {getattr(args, n)}')

######################################################################
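# Generates nb_samples sequences of nb_tokens_to_generate tokens,
# optionally starting from a primer. At each position the model is run
# on the sequence so far, and the next token is either sampled from the
# predicted distribution or taken as the argmax, depending on
# --synthesis_sampling.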
def autoregression(
        model, batch_size,
        nb_samples, nb_tokens_to_generate, primer = None,
        device = torch.device('cpu')
):
    results = torch.zeros(
        nb_samples, nb_tokens_to_generate,
        dtype = torch.int64, device = device
    )

    if primer is None:
        first = 0
    else:
        first = primer.size(1)
        results = torch.cat((primer, results), 1)

    for input in results.split(batch_size):
        for s in range(first, input.size(1)):
            output = model(input)
            logits = output[:, s]
            if args.synthesis_sampling:
                dist = torch.distributions.categorical.Categorical(logits = logits)
                t_next = dist.sample()
            else:
                t_next = logits.argmax(1)
            input[:, s] = t_next

    return results
######################################################################

class Task:
    def batches(self, split = 'train'):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(self, n_epoch, model):
        pass

######################################################################

import picoclvr
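# In the picoclvr task, each sample is a textual scene description
# ('red above green <sep> ... <img>') followed by the tokens encoding
# the corresponding height x width grid; the model is trained to
# generate the grid tokens from the description.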
class TaskPicoCLVR(Task):

    # Make a tensor from a list of strings
    def tensorize(self, descr):
        token_descr = [ s.strip().split(' ') for s in descr ]
        l = max([ len(s) for s in token_descr ])
        #token_descr = [ [ '<nul>' ] * (l - len(s)) + s for s in token_descr ]
        token_descr = [ s + [ '<nul>' ] * (l - len(s)) for s in token_descr ]
        id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ]
        return torch.tensor(id_descr, device = self.device)
    def __init__(self, batch_size,
                 height, width, nb_colors = 5,
                 device = torch.device('cpu')):

        def generate_descr(nb):
            return picoclvr.generate(
                nb,
                height = self.height, width = self.width,
                nb_colors = nb_colors
            )

        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        nb = args.data_size if args.data_size > 0 else 250000

        self.train_descr = generate_descr((nb * 4) // 5)
        self.test_descr = generate_descr((nb * 1) // 5)

        # Build the tokenizer
        tokens = set()
        for d in [ self.train_descr, self.test_descr ]:
            for s in d:
                for t in s.strip().split(' '): tokens.add(t)
        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)
    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        input = self.train_input if split == 'train' else self.test_input
        for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
            yield batch

    def vocabulary_size(self):
        return len(self.token2id)
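    # Generation test: complete a few fixed primers into full scenes,
    # save the resulting images, and measure how well the scenes satisfy
    # the requested properties.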
    def produce_results(self, n_epoch, model):
        nb_tokens = self.height * self.width + 3
        result_descr = [ ]
        nb_per_primer = 8 # number of samples generated per primer

        for primer_descr in [
                'red above green <sep> green top <sep> blue right of red <img>',
                'there is red <sep> there is yellow <sep> there is blue <img>',
                'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
                'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
        ]:

            for k in range(nb_per_primer):
                results = autoregression(
                    model, self.batch_size,
                    nb_samples = 1, nb_tokens_to_generate = nb_tokens,
                    primer = self.tensorize([ primer_descr ]),
                    device = self.device
                )
                r = ' '.join([ self.id2token[t.item()] for t in results.flatten() ])
                result_descr.append(r)

        img = [
            picoclvr.descr2img(d, height = self.height, width = self.width)
            for d in result_descr
        ]

        img = torch.cat(img, 0)
        image_name = f'result_picoclvr_{n_epoch:04d}.png'
        torchvision.utils.save_image(
            img / 255.,
            image_name, nrow = nb_per_primer, pad_value = 0.8
        )
        log_string(f'wrote {image_name}')
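        # nb_properties returns, for each description, a triple whose
        # first and last components are the number of requested and of
        # missing properties in the generated scene.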
        np = picoclvr.nb_properties(
            result_descr,
            height = self.height, width = self.width
        )

        nb_requested_properties, _, nb_missing_properties = zip(*np)

        log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
######################################################################

class TaskWiki103(Task):

    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.len_min = len_min
        self.len_max = len_max
        self.min_freq = min_freq
        self.device = device
        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

        def yield_tokens():
            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
                yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            yield_tokens(),
            specials = [ '<unk>', '<nul>' ],
            min_freq = self.min_freq
        )

        self.vocab.set_default_index(self.vocab[ '<unk>' ])
    # Makes a tensor from a list of lists of tokens
    def tensorize(self, s):
        a = max(len(x) for x in s)
        return torch.tensor([ self.vocab(x + [ '<nul>' ] * (a - len(x))) for x in s ])
    def yield_batches(self, ds):
        s = [ ]
        for l in ds:
            q = self.tokenizer(l)
            if len(q) >= self.len_min and len(q) <= self.len_max:
                s += [ q ]
                if len(s) == self.batch_size:
                    yield self.tensorize(s)
                    s = [ ]

        if len(s) > 0:
            yield self.tensorize(s)
    def batches(self, split = 'train'):
        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

        if args.data_size > 0:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))

    def vocabulary_size(self):
        return len(self.vocab)
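    # Generation test: for each primer, tokens are generated one at a
    # time. The sequence primer + generated is re-encoded and fed to the
    # model at every step, with F.pad appending the empty slot whose
    # prediction gives the next token.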
    def produce_results(self, n_epoch, model):
        nb_tokens = 50 # number of tokens to generate after each primer
        file_name = f'result_wiki103_{n_epoch:04d}.txt'

        with open(file_name, 'w') as outfile:
            for primer in [
                    'the cat is hunting a',
                    'paris is the capital',
                    'cars are convenient',
                    'the difference between men and women is',
                    'the object was blue all over and green all over it was',
                    'cherries are red and lemons are',
                    'cherries are sweet and lemons are',
                    'two plus three equals',
            ]:
                t_primer = self.tokenizer(primer)
                t_generated = [ ]

                for j in range(nb_tokens):
                    input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                    input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                    output = model(input)
                    logits = output[0, -1]
                    if args.synthesis_sampling:
                        dist = torch.distributions.categorical.Categorical(logits = logits)
                        t_next = dist.sample()
                    else:
                        t_next = logits.argmax()
                    t_generated.append(self.vocab.lookup_token(t_next))
                    if t_generated[-1] == '<nul>': break

                s = ' '.join(t_generated)

                outfile.write(f'<{primer}> {s}\n')

        log_string(f'wrote {file_name}')
######################################################################
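# MNIST as a sequence task: each image is flattened into 784 pixel
# tokens with values in {0, ..., 255}, and the model generates images
# pixel by pixel.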
class TaskMNIST(Task):

    def __init__(self, batch_size, device = torch.device('cpu')):
        self.batch_size = batch_size
        self.device = device

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),
            download = True
        )
        data_input = data_set.data.view(-1, 28 * 28).long()
        if args.data_size >= 0:
            data_input = data_input[:args.data_size]
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
            yield batch

    def vocabulary_size(self):
        return 256

    def produce_results(self, n_epoch, model):
        nb_samples = 64
        results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
        image_name = f'result_mnist_{n_epoch:04d}.png'
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')
######################################################################

log_string(f'device {device}')

if args.data == 'wiki103':
    nb_epochs_default = 10
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    nb_epochs_default = 25
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    nb_epochs_default = 10
    task = TaskPicoCLVR(batch_size = args.batch_size,
                        height = args.picoclvr_height,
                        width = args.picoclvr_width,
                        nb_colors = args.picoclvr_nb_colors,
                        device = device)
else:
    raise ValueError(f'Unknown dataset {args.data}.')

vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')
##############################

model = mygpt.MyGPT(
    vocabulary_size = vocabulary_size,
    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
)

model.to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
######################################################################

if args.optim == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
else:
    raise ValueError(f'Unknown optimizer {args.optim}.')
######################################################################
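# Unless --no_checkpoint is given, try to resume training from the last
# saved checkpoint: model and optimizer states, plus the number of
# epochs already finished.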
nb_epochs_finished = 0

if args.no_checkpoint:
    log_string('not trying to load checkpoint.')

else:
    try:
        checkpoint = torch.load(args.checkpoint_name, map_location = device)
        nb_epochs_finished = checkpoint['nb_epochs_finished']
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')

    except FileNotFoundError:
        log_string('starting from scratch.')

    except:
        log_string('error when loading the checkpoint.')
        exit(1)
######################################################################

nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default

token_count = 0
for input in task.batches(split = 'train'):
    token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
#log_string(f'train set perplexity {train_set_perplexity}')
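# The train-set perplexity above is the exponential of the unigram
# entropy of the training tokens: the perplexity a model ignoring all
# context would reach, and hence a baseline for the prediction
# perplexities logged below.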
for k in range(nb_epochs_finished, nb_epochs):

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.autograd.no_grad():

        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)
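        # Clamping the average loss at 100 before exponentiating avoids
        # overflow in math.exp when the model is still very bad.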
        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

        log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')

        task.produce_results(k, model)

    checkpoint = {
        'nb_epochs_finished': k + 1,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict()
    }

    torch.save(checkpoint, args.checkpoint_name)
######################################################################