# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
import math, sys, argparse, time, tqdm, itertools

import torch, torchtext, torchvision

from torch.nn import functional as F

# NOTE(review): project-local modules used throughout this file; their
# import line was lost in the garbled listing -- confirm the module
# names against the upstream source.
import mygpt, picoclvr
######################################################################

# Run on the GPU when one is available, fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

######################################################################
parser = argparse.ArgumentParser(description = 'My own GPT.')

parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--seed',
                    type = int, default = 0)

# -1 means "use the per-dataset default chosen below"
parser.add_argument('--nb_epochs',
                    type = int, default = -1)

parser.add_argument('--batch_size',
                    type = int, default = 25)

parser.add_argument('--data',
                    type = str, default = 'wiki103')

# -1 means "use the full dataset"
parser.add_argument('--data_size',
                    type = int, default = -1)

parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

# BUG FIX: the original declared action='store_true' with default=True,
# which made the flag a no-op -- sampling could never be turned off and
# the greedy-argmax branches downstream were unreachable. Keep the flag
# and its default for backward compatibility, and add an explicit
# opt-out so greedy decoding becomes reachable.
parser.add_argument('--synthesis_sampling',
                    dest = 'synthesis_sampling',
                    action = 'store_true', default = True)

parser.add_argument('--no_synthesis_sampling',
                    dest = 'synthesis_sampling',
                    action = 'store_false')

parser.add_argument('--no_checkpoint',
                    action='store_true', default = False)

parser.add_argument('--checkpoint_name',
                    type = str, default = 'checkpoint.pth')
##############################
# PicoCLVR-specific options: scene geometry and palette size.

parser.add_argument('--picoclvr_nb_colors', type = int, default = 5)

parser.add_argument('--picoclvr_height', type = int, default = 12)

parser.add_argument('--picoclvr_width', type = int, default = 16)

######################################################################
args = parser.parse_args()

# Kept open for the whole run; log_string() writes and flushes it.
log_file = open(args.log_filename, 'w')

torch.manual_seed(args.seed)

######################################################################
def log_string(s):
    """Write s, prefixed with a timestamp, to the log file and stdout.

    NOTE(review): the def header and the stdout-echo lines were missing
    from the garbled listing and were reconstructed -- confirm against
    the upstream source.
    """
    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
    if log_file is not None:
        log_file.write(t + s + '\n')
        log_file.flush()
    print(t + s)
    sys.stdout.flush()

# Record the full configuration at the top of the log.
for n in vars(args):
    log_string(f'args.{n} {getattr(args, n)}')

######################################################################
def autoregression(
        model, batch_size,
        nb_samples, nb_tokens_to_generate, starting_input = None,
        device = torch.device('cpu')
):
    """Generate nb_samples token sequences with the model, batch_size
    rows at a time, optionally continuing from starting_input.

    Relies on the module-level args.synthesis_sampling: sample from the
    model's categorical distribution when True, take the argmax
    otherwise.

    NOTE(review): the def header, the `first = 0` / `else:` lines and
    the final write-back/return were missing from the garbled listing
    and were reconstructed -- confirm against the upstream source.
    """
    results = torch.zeros(
        nb_samples, nb_tokens_to_generate,
        dtype = torch.int64, device = device
    )

    if starting_input is None:
        first = 0
    else:
        first = starting_input.size(1)
        results = torch.cat((starting_input, results), 1)

    # Tensor.split returns views, so writing into `input` below fills
    # `results` in place.
    for input in results.split(batch_size):
        for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
            output = model(input)
            logits = output[:, s]
            if args.synthesis_sampling:
                dist = torch.distributions.categorical.Categorical(logits = logits)
                t_next = dist.sample()
            else:
                t_next = logits.argmax(1)
            input[:, s] = t_next

    return results

######################################################################
class Task:
    """Abstract interface every dataset/task below implements.

    NOTE(review): the class header and the method bodies were missing
    from the garbled listing; reconstructed as no-op stubs -- confirm
    against the upstream source.
    """

    def batches(self, split = 'train'):
        # Yield training or test batches as integer token tensors.
        pass

    def vocabulary_size(self):
        # Number of distinct token values the model must handle.
        pass

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        # Generate and store sample outputs for epoch n_epoch.
        pass

######################################################################
class TaskPicoCLVR(Task):
    """PicoCLVR scenes described as token sequences.

    Descriptions are generated by the project-local `picoclvr` module,
    padded to a common length with '<unk>', and tokenized word-by-word.

    NOTE(review): numerous lines (closing parens, attribute
    assignments, `else:` branches) were missing from the garbled
    listing and were reconstructed -- confirm against upstream.
    """

    def __init__(self, batch_size,
                 height, width, nb_colors = 5,
                 device = torch.device('cpu')):

        def generate_descr(nb):
            # Generate nb scene descriptions of the configured geometry.
            descr = picoclvr.generate(
                nb,
                height = self.height, width = self.width,
                nb_colors = nb_colors
            )

            descr = [ s.strip().split(' ') for s in descr ]
            l = max([ len(s) for s in descr ])
            # Right-pad every description to the common length.
            #descr = [ [ '<unk>' ] * (l - len(s)) + s for s in descr ]
            descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]

            return descr

        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device

        nb = args.data_size if args.data_size > 0 else 250000

        self.train_descr = generate_descr((nb * 4) // 5)
        self.test_descr = generate_descr((nb * 1) // 5)

        # Build the tokenizer
        tokens = set()
        for d in [ self.train_descr, self.test_descr ]:
            for s in d:
                for t in s: tokens.add(t)
        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

        # Tokenize the train and test sets
        t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
        self.train_input = torch.tensor(t, device = self.device)
        t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
        self.test_input = torch.tensor(t, device = self.device)

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        if split == 'train':
            for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
                yield batch
        else:
            for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
                yield batch

    def vocabulary_size(self):
        return len(self.token2id)

    def generate(self, primer, model, nb_tokens):
        """Autoregressively extend `primer` by nb_tokens tokens and
        return the whole sequence as a single space-joined string."""
        t_primer = primer.strip().split(' ')
        t_generated = [ ]

        for j in range(nb_tokens):
            t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
            input = torch.tensor(t, device = self.device)
            input = F.pad(input, (0, 1)) # Add the next token, the one to predict
            output = model(input)
            logits = output[0, -1]
            if args.synthesis_sampling:
                dist = torch.distributions.categorical.Categorical(logits = logits)
                t_next = dist.sample()
            else:
                t_next = logits.argmax()
            t_generated.append(self.id2token[t_next.item()])

        return ' '.join(t_primer + t_generated)

    def produce_results(self, n_epoch, model, nb_tokens = None):
        if nb_tokens is None:
            # One token per grid cell plus a small margin for separators.
            nb_tokens = self.height * self.width + 3

        descr = [ ]
        nb_per_primer = 8 # NOTE(review): value lost in the garbled listing -- confirm

        for primer in [
                'red above green <sep> green top <sep> blue right of red <img>',
                'there is red <sep> there is yellow <sep> there is blue <img>',
                'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
                'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
        ]:
            for k in range(nb_per_primer):
                descr.append(self.generate(primer, model, nb_tokens))

        img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
        img = torch.cat(img, 0)
        image_name = f'result_picoclvr_{n_epoch:04d}.png'
        torchvision.utils.save_image(
            img / 255., # NOTE(review): first argument lost in the garbled listing -- confirm scaling
            image_name, nrow = nb_per_primer, pad_value = 0.8
        )
        log_string(f'wrote {image_name}')

        # Per-description property statistics (requested vs missing).
        np = picoclvr.nb_properties(
            descr,
            height = self.height, width = self.width
        )

        nb_requested_properties, _, nb_missing_properties = zip(*np)

        log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(descr):.02f}')

######################################################################
class TaskWiki103(Task):
    """WikiText-103 language modeling via torchtext.

    Lines of len_min..len_max tokens are batched and padded with
    '<non>'; out-of-vocabulary tokens map to '<unk>'.

    NOTE(review): several lines (attribute assignment, the
    `yield_tokens` def header, batch-accumulator statements, `else:`)
    were missing from the garbled listing and were reconstructed --
    confirm against upstream.
    """

    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.len_min = len_min
        self.len_max = len_max
        self.min_freq = min_freq
        self.device = device

        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

        # Mostly for debugging: cap the corpus size.
        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

        def yield_tokens():
            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
                yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            yield_tokens(),
            specials = [ '<unk>', '<non>' ],
            min_freq = self.min_freq
        )

        self.vocab.set_default_index(self.vocab[ '<unk>' ])

    def tensorize(self, s):
        # Pad all token lists in s to the longest with '<non>' and map
        # them to vocabulary ids.
        a = max(len(x) for x in s)
        return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])

    def yield_batches(self, ds):
        s = [ ]
        for l in ds:
            q = self.tokenizer(l)
            # Keep only lines whose length is in [len_min, len_max].
            if len(q) >= self.len_min and len(q) <= self.len_max:
                s.append(q)
                if len(s) == self.batch_size:
                    yield self.tensorize(s)
                    s = [ ]

        # Flush the final, possibly short, batch.
        if len(s) > 0:
            yield self.tensorize(s)

    def batches(self, split = 'train'):
        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

        # Mostly for debugging: cap the corpus size.
        if args.data_size > 0:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))

    def vocabulary_size(self):
        return len(self.vocab)

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        file_name = f'result_wiki103_{n_epoch:04d}.txt'

        with open(file_name, 'w') as outfile:
            for primer in [
                    'the cat is hunting a',
                    'paris is the capital',
                    'cars are convenient',
                    'the difference between men and women is',
                    'the object was blue all over and green all over it was',
                    'cherries are red and lemons are',
                    'cherries are sweet and lemons are',
                    'two plus three equals',
            ]:
                t_primer = self.tokenizer(primer)
                t_generated = [ ]

                for j in range(nb_tokens):
                    input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                    input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                    output = model(input)
                    logits = output[0, -1]
                    if args.synthesis_sampling:
                        dist = torch.distributions.categorical.Categorical(logits = logits)
                        t_next = dist.sample()
                    else:
                        t_next = logits.argmax()
                    t_generated.append(self.vocab.lookup_token(t_next))
                    # '<non>' is the padding token: stop at end of sequence.
                    if t_generated[-1] == '<non>': break

                s = ' '.join(t_generated)

                outfile.write(f'<{primer}> {s}\n')

        log_string(f'wrote {file_name}')

######################################################################
class TaskMNIST(Task):
    """MNIST digits flattened to length-784 sequences of pixel values.

    NOTE(review): a few lines (attribute assignment, `download = True)`,
    `yield batch`, `return 256`) were missing from the garbled listing
    and were reconstructed -- confirm against upstream.
    """

    def __init__(self, batch_size, device = torch.device('cpu')):
        self.device = device
        self.batch_size = batch_size

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),
            download = True
        )
        data_input = data_set.data.view(-1, 28 * 28).long()
        # NOTE(review): other tasks test `> 0`; here `>= 0` makes
        # --data_size 0 an empty dataset -- confirm this is intended.
        if args.data_size >= 0:
            data_input = data_input[:args.data_size]
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
            yield batch

    def vocabulary_size(self):
        return 256 # one token per 8-bit pixel intensity

    def produce_results(self, n_epoch, model, nb_samples = 64):
        results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
        image_name = f'result_mnist_{n_epoch:04d}.png'
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')

######################################################################
log_string(f'device {device}')

# Select the task; nb_epochs_default is used when --nb_epochs is left
# at its -1 default.
# NOTE(review): the closing `device = device)` and the `else:` branch
# were missing from the garbled listing and were reconstructed.
if args.data == 'wiki103':
    nb_epochs_default = 10
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    nb_epochs_default = 25
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    nb_epochs_default = 10
    task = TaskPicoCLVR(batch_size = args.batch_size,
                        height = args.picoclvr_height,
                        width = args.picoclvr_width,
                        nb_colors = args.picoclvr_nb_colors,
                        device = device)
else:
    raise ValueError(f'Unknown dataset {args.data}.')

vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')
##############################

# NOTE(review): the constructor line, the closing paren and the
# `model.to(device)` move were missing from the garbled listing;
# `mygpt.MyGPT` is inferred from the script's description -- confirm
# against upstream.
model = mygpt.MyGPT(
    vocabulary_size = vocabulary_size,
    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
)

model.to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')

######################################################################
# Build the optimizer selected on the command line.
# NOTE(review): the `else:` line was missing from the garbled listing
# and was reconstructed (the orphaned `raise` implies it).
if args.optim == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
else:
    raise ValueError(f'Unknown optimizer {args.optim}.')

######################################################################
nb_epochs_finished = 0

# Resume from a checkpoint unless explicitly disabled.
# NOTE(review): the `else:` / `try:` lines and the final handler were
# missing from the garbled listing and were reconstructed.
if args.no_checkpoint:
    log_string(f'not trying to load checkpoint.')

else:
    try:
        checkpoint = torch.load(args.checkpoint_name, map_location = device)
        nb_epochs_finished = checkpoint['nb_epochs_finished']
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')

    except FileNotFoundError:
        # No checkpoint yet: normal on a first run.
        log_string('starting from scratch.')

    except Exception:
        # Any other failure (corrupt file, mismatched model, ...) is
        # fatal; `except Exception` rather than a bare `except` so that
        # KeyboardInterrupt/SystemExit still propagate.
        log_string('error when loading the checkpoint.')
        exit(1)

######################################################################
nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default

# Entropy of the empirical unigram distribution: the best perplexity a
# memoryless model could reach on the training set.
# NOTE(review): the `token_count = 0` initializer was missing from the
# garbled listing and was reconstructed.
token_count = 0
for input in task.batches(split = 'train'):
    token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
h = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(h)
log_string(f'train set perplexity {train_set_perplexity}')
# Main training loop, resuming after any epochs already in the checkpoint.
# NOTE(review): `model.train()`, `loss.backward()`, `optimizer.step()`,
# `model.eval()` and the `checkpoint = { ... }` braces were missing
# from the garbled listing and were reconstructed -- confirm against
# upstream.
for k in range(nb_epochs_finished, nb_epochs):

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        loss = F.cross_entropy(output.transpose(1, 2), input)
        # Accumulate the per-sample loss so the average is exact even
        # with a short final batch.
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.autograd.no_grad():

        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        # Clamp the exponent to avoid math.exp overflow early in training.
        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

        log_string(f'perplexity {k} train {train_perplexity} test {test_perplexity}')

        task.produce_results(k, model)

    checkpoint = {
        'nb_epochs_finished': k + 1,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict()
    }

    torch.save(checkpoint, args.checkpoint_name)

######################################################################