# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
import argparse
import itertools
import math
import sys
import time

import torch
import torchtext
import torchvision
import tqdm
from torch.nn import functional as F
16 ######################################################################
# Run on the GPU when one is available; every tensor and the model are
# later moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20 ######################################################################
# Command-line configuration. Defaults give a 12-block, 512-dim model
# trained with Adam at 1e-4 on WikiText-103.
parser = argparse.ArgumentParser(description = 'My own GPT.')

# Logging / reproducibility
parser.add_argument('--log_filename', type = str, default = 'train.log')
parser.add_argument('--seed', type = int, default = 0)

# Training schedule and data selection
parser.add_argument('--nb_epochs', type = int, default = -1)
parser.add_argument('--batch_size', type = int, default = 25)
parser.add_argument('--data', type = str, default = 'wiki103')
parser.add_argument('--data_size', type = int, default = -1)

# Optimization
parser.add_argument('--optim', type = str, default = 'adam')
parser.add_argument('--learning_rate', type = float, default = 1e-4)

# Model geometry
parser.add_argument('--dim_model', type = int, default = 512)
parser.add_argument('--dim_keys', type = int, default = 64)
parser.add_argument('--dim_hidden', type = int, default = 2048)
parser.add_argument('--nb_heads', type = int, default = 8)
parser.add_argument('--nb_blocks', type = int, default = 12)
parser.add_argument('--dropout', type = float, default = 0.1)

# Sampling and checkpointing
# NOTE(review): store_true combined with default=True means this flag is
# always True and cannot be switched off from the command line.
parser.add_argument('--synthesis_sampling', action='store_true', default = True)

parser.add_argument('--no_checkpoint', action='store_true', default = False)
parser.add_argument('--checkpoint_name', type = str, default = 'checkpoint.pth')

##############################
# picoclvr-specific options

parser.add_argument('--picoclvr_nb_colors', type = int, default = 5)
parser.add_argument('--picoclvr_height', type = int, default = 12)
parser.add_argument('--picoclvr_width', type = int, default = 16)
86 ######################################################################
# Parse the CLI options; the resulting namespace is read as a global
# throughout the script.
args = parser.parse_args()

# Opened in write mode: each run truncates any previous log.
log_file = open(args.log_filename, 'w')

# Seed the global torch RNG so sampling and initialization are reproducible.
torch.manual_seed(args.seed)
95 ######################################################################
def log_string(s):
    """Append *s* to the log file, prefixed with a wall-clock timestamp."""
    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
    if log_file is not None:
        log_file.write(t + s + '\n')
        # Flush so the log stays readable during long training runs.
        log_file.flush()

# Record the full configuration at start-up so a run is self-describing.
for n in vars(args):
    log_string(f'args.{n} {getattr(args, n)}')
110 ######################################################################
def autoregression(model, batch_size,
                   nb_samples, nb_tokens_to_generate, primer = None,
                   device = torch.device('cpu')):
    """Generate token sequences autoregressively with *model*.

    Returns an int64 tensor of nb_samples rows. When *primer* is given
    (a (nb_samples, first) tensor) its tokens are kept as a prefix and
    generation starts after it. Sampling is either stochastic (from the
    model's categorical distribution over logits) or greedy argmax,
    depending on args.synthesis_sampling.
    """
    results = torch.zeros(
        nb_samples, nb_tokens_to_generate,
        dtype = torch.int64, device = device
    )

    if primer is None:
        first = 0
    else:
        first = primer.size(1)
        results = torch.cat((primer, results), 1)

    # Each split is a view into `results`, so writing into `input`
    # fills `results` in place.
    for input in results.split(batch_size):
        for s in range(first, input.size(1)):
            output = model(input)
            logits = output[:, s]
            if args.synthesis_sampling:
                dist = torch.distributions.categorical.Categorical(logits = logits)
                t_next = dist.sample()
            else:
                t_next = logits.argmax(1)
            input[:, s] = t_next

    return results
141 ######################################################################
class Task:
    """Abstract interface shared by the dataset/benchmark tasks below."""

    def batches(self, split = 'train'):
        """Yield batches of integer token tensors for the given split."""
        pass

    def vocabulary_size(self):
        """Return the number of distinct tokens of the task."""
        pass

    def produce_results(self, n_epoch, model):
        """Generate and save sample outputs for epoch *n_epoch*."""
        pass
153 ######################################################################
class TaskPicoCLVR(Task):
    """Synthetic scene-description task built on the project-local
    picoclvr module: sequences are textual constraints followed by an
    <img> token and the grid description."""

    def tensorize(self, descr):
        """Tokenize a list of description strings into a padded int tensor.

        Sequences are right-padded with '<nul>' to the longest length in
        the batch.
        """
        descr = [ s.strip().split(' ') for s in descr ]
        l = max([ len(s) for s in descr ])
        descr = [ s + [ '<nul>' ] * (l - len(s)) for s in descr ]
        t = [ [ self.token2id[u] for u in s ] for s in descr ]
        return torch.tensor(t, device = self.device)

    def __init__(self, batch_size,
                 height, width, nb_colors = 5,
                 device = torch.device('cpu')):

        def generate_descr(nb):
            return picoclvr.generate(
                nb,
                height = self.height, width = self.width,
                nb_colors = nb_colors
            )

        self.height = height
        self.width = width
        self.device = device
        self.batch_size = batch_size

        # args.data_size <= 0 means "use the full default size".
        nb = args.data_size if args.data_size > 0 else 250000

        # 80% / 20% train/test split.
        self.train_descr = generate_descr((nb * 4) // 5)
        self.test_descr = generate_descr((nb * 1) // 5)

        # Build the tokenizer
        tokens = set()
        for d in [ self.train_descr, self.test_descr ]:
            for s in d:
                for t in s.strip().split(' '): tokens.add(t)
        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        input = self.train_input if split == 'train' else self.test_input
        for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
            yield batch

    def vocabulary_size(self):
        return len(self.token2id)

    def produce_results(self, n_epoch, model):
        # +3 tokens of slack beyond the raw grid size — TODO confirm.
        nb_tokens = self.height * self.width + 3
        result_descr = [ ]
        # NOTE(review): number of samples per primer reconstructed; the
        # value only affects how many images are generated per prompt.
        nb_per_primer = 8

        for primer_descr in [
            'red above green <sep> green top <sep> blue right of red <img>',
            'there is red <sep> there is yellow <sep> there is blue <img>',
            'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
            'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
        ]:
            for k in range(nb_per_primer):
                results = autoregression(
                    model, self.batch_size,
                    nb_samples = 1, nb_tokens_to_generate = nb_tokens,
                    primer = self.tensorize([ primer_descr ]),
                    device = self.device
                )
                r = ' '.join([ self.id2token[t.item()] for t in results.flatten() ])
                result_descr.append(r)

        img = [
            picoclvr.descr2img(d, height = self.height, width = self.width)
            for d in result_descr
        ]

        img = torch.cat(img, 0)
        image_name = f'result_picoclvr_{n_epoch:04d}.png'
        # assumes descr2img returns uint8-range values — TODO confirm
        torchvision.utils.save_image(
            img / 255.,
            image_name, nrow = nb_per_primer, pad_value = 0.8
        )
        log_string(f'wrote {image_name}')

        np = picoclvr.nb_properties(
            result_descr,
            height = self.height, width = self.width
        )

        nb_requested_properties, _, nb_missing_properties = zip(*np)

        log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
252 ######################################################################
class TaskWiki103(Task):
    """Word-level language modeling on WikiText-103 via torchtext.

    Lines are filtered to [len_min, len_max] tokens; the vocabulary keeps
    words seen at least min_freq times, plus '<unk>' and '<nul>' specials.
    """

    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.len_min = len_min
        self.len_max = len_max
        self.min_freq = min_freq
        self.device = device

        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

        # Mostly for debugging: cap the amount of data scanned.
        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

        def yield_tokens():
            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
                yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            yield_tokens(),
            specials = [ '<unk>', '<nul>' ],
            min_freq = self.min_freq
        )

        # Out-of-vocabulary words map to '<unk>'.
        self.vocab.set_default_index(self.vocab[ '<unk>' ])

    def tensorize(self, s):
        """Convert a list of token lists into a '<nul>'-padded int tensor."""
        a = max(len(x) for x in s)
        return torch.tensor([ self.vocab(x + [ '<nul>' ] * (a - len(x))) for x in s ])

    def yield_batches(self, ds):
        """Group the lines of *ds* into tensorized batches, keeping only
        lines whose token length is within [len_min, len_max]."""
        s = [ ]
        for l in ds:
            q = self.tokenizer(l)
            if len(q) >= self.len_min and len(q) <= self.len_max:
                s += [ q ]
                if len(s) == self.batch_size:
                    yield self.tensorize(s)
                    s = [ ]

        # Flush the last, possibly incomplete, batch.
        if len(s) > 0:
            yield self.tensorize(s)

    def batches(self, split = 'train'):
        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

        # Mostly for debugging: cap the amount of data scanned.
        if args.data_size > 0:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))

    def vocabulary_size(self):
        return len(self.vocab)

    def produce_results(self, n_epoch, model):
        # NOTE(review): generation length reconstructed — only bounds how
        # many tokens are sampled per primer.
        nb_tokens = 50
        file_name = f'result_wiki103_{n_epoch:04d}.txt'

        with open(file_name, 'w') as outfile:
            for primer in [
                'the cat is hunting a',
                'paris is the capital',
                'cars are convenient',
                'the difference between men and women is',
                'the object was blue all over and green all over it was',
                'cherries are red and lemons are',
                'cherries are sweet and lemons are',
                'two plus three equals',
            ]:
                t_primer = self.tokenizer(primer)
                t_generated = [ ]

                for j in range(nb_tokens):
                    input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                    input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                    output = model(input)
                    logits = output[0, -1]
                    if args.synthesis_sampling:
                        dist = torch.distributions.categorical.Categorical(logits = logits)
                        t_next = dist.sample()
                    else:
                        t_next = logits.argmax()
                    t_generated.append(self.vocab.lookup_token(t_next))
                    # Stop at the padding token.
                    if t_generated[-1] == '<nul>': break

                s = ' '.join(t_generated)

                outfile.write(f'<{primer}> {s}\n')

        log_string(f'wrote {file_name}')
352 ######################################################################
class TaskMNIST(Task):
    """Pixel-level autoregressive modeling of MNIST: each image is a
    sequence of 28*28 = 784 tokens, one per 8-bit gray level."""

    def __init__(self, batch_size, device = torch.device('cpu')):
        self.batch_size = batch_size
        # Needed by produce_results below.
        self.device = device

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),
            download = True
        )
        # Flatten each 28x28 image into a length-784 token sequence.
        data_input = data_set.data.view(-1, 28 * 28).long()
        if args.data_size >= 0:
            data_input = data_input[:args.data_size]
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
            yield batch

    def vocabulary_size(self):
        # One token per 8-bit pixel value (values are divided by 255. when
        # images are written out below).
        return 256

    def produce_results(self, n_epoch, model):
        # NOTE(review): sample count reconstructed; affects only how many
        # digits appear in the saved image grid.
        nb_samples = 64
        results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
        image_name = f'result_mnist_{n_epoch:04d}.png'
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')
383 ######################################################################
log_string(f'device {device}')

# Instantiate the task selected on the command line; each task also fixes
# a sensible default number of epochs.
if args.data == 'wiki103':
    nb_epochs_default = 10
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    nb_epochs_default = 25
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    nb_epochs_default = 10
    task = TaskPicoCLVR(batch_size = args.batch_size,
                        height = args.picoclvr_height,
                        width = args.picoclvr_width,
                        nb_colors = args.picoclvr_nb_colors,
                        device = device)
else:
    raise ValueError(f'Unknown dataset {args.data}.')

vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')

##############################

# NOTE(review): constructor name reconstructed from the keyword arguments;
# confirm against the project-local mygpt module.
model = mygpt.MyGPT(
    vocabulary_size = vocabulary_size,
    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
)

model.to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
420 ######################################################################
# Select the optimizer; an unknown name fails fast instead of silently
# leaving `optimizer` undefined (the visible raise had no else branch).
if args.optim == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
else:
    raise ValueError(f'Unknown optimizer {args.optim}.')
431 ######################################################################
nb_epochs_finished = 0

# Resume from a checkpoint unless explicitly disabled. A missing file is
# normal (first run); any other load error aborts rather than silently
# training from scratch over a half-read state.
if args.no_checkpoint:
    log_string(f'not trying to load checkpoint.')
else:
    try:
        checkpoint = torch.load(args.checkpoint_name, map_location = device)
        nb_epochs_finished = checkpoint['nb_epochs_finished']
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')

    except FileNotFoundError:
        log_string('starting from scratch.')

    except:
        log_string('error when loading the checkpoint.')
        exit(1)
453 ######################################################################
nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default

# Empirical token distribution of the train set, used to report the
# entropy-based perplexity lower bound of the data itself.
token_count = 0
for input in task.batches(split = 'train'):
    token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)

for k in range(nb_epochs_finished, nb_epochs):

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        # The visible fragment computed the loss but never backpropagated:
        # restore the standard update step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.autograd.no_grad():

        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        # Clamp the exponent so a diverging loss cannot overflow exp().
        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

        log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')

        task.produce_results(k, model)

    # Persist everything needed to resume after this epoch.
    checkpoint = {
        'nb_epochs_finished': k + 1,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict()
    }

    torch.save(checkpoint, args.checkpoint_name)
510 ######################################################################