3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, argparse, time, tqdm, itertools
10 import torch, torchtext, torchvision
12 from torch.nn import functional as F
16 ######################################################################
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20 ######################################################################
# Command-line interface: logging, training schedule, dataset choice,
# optimizer, model geometry, generation and checkpointing options.
parser = argparse.ArgumentParser(description = 'My own GPT.')

# Logging / reproducibility
parser.add_argument('--log_filename', type = str, default = 'train.log')
parser.add_argument('--seed', type = int, default = 0)

# Training schedule (-1 epochs: use the per-dataset default chosen later)
parser.add_argument('--nb_epochs', type = int, default = -1)
parser.add_argument('--batch_size', type = int, default = 25)

# Dataset selection and optional truncation (-1: full dataset)
parser.add_argument('--data', type = str, default = 'wiki103')
parser.add_argument('--data_size', type = int, default = -1)

# Optimization
parser.add_argument('--optim', type = str, default = 'adam')
parser.add_argument('--learning_rate', type = float, default = 1e-4)

# Model geometry
parser.add_argument('--dim_model', type = int, default = 512)
parser.add_argument('--dim_keys', type = int, default = 64)
parser.add_argument('--dim_hidden', type = int, default = 2048)
parser.add_argument('--nb_heads', type = int, default = 8)
parser.add_argument('--nb_blocks', type = int, default = 12)
parser.add_argument('--dropout', type = float, default = 0.1)

# Generation / checkpointing.
# NOTE(review): store_true with default=True means --synthesis_sampling
# cannot be switched off from the command line; kept for compatibility.
parser.add_argument('--synthesis_sampling', action='store_true', default = True)
parser.add_argument('--no_checkpoint', action='store_true', default = False)
parser.add_argument('--checkpoint_name', type = str, default = 'checkpoint.pth')

##############################
# picoclvr-specific options

parser.add_argument('--picoclvr_nb_colors', type = int, default = 5)
parser.add_argument('--picoclvr_height', type = int, default = 12)
parser.add_argument('--picoclvr_width', type = int, default = 16)
86 ######################################################################
# Parse the command line and set up global state.
args = parser.parse_args()

# Kept open for the lifetime of the process; log output is appended to
# it. Mode 'w' truncates any previous log with the same name.
log_file = open(args.log_filename, 'w')

# Seed the PRNG for reproducible runs.
torch.manual_seed(args.seed)
95 ######################################################################
# NOTE(review): incomplete fragment. This looks like the body of a
# `def log_string(s):` helper (its `def` line is not visible in this
# chunk), followed by a loop over `vars(args)` whose `for` header is
# also missing. Kept verbatim; confirm against the full file.

# Timestamp prefix for each log line, e.g. '20240101-12:00:00 '.
t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
if log_file is not None:
    # Mirror the message to the log file (any print to stdout is on a
    # line not visible in this chunk).
    log_file.write(t + s + '\n')

# Dump every parsed argument to the log (loop header missing here).
log_string(f'args.{n} {getattr(args, n)}')
110 ######################################################################
# NOTE(review): incomplete fragment. The two lines below read like the
# tail of a `def autoregression(model, batch_size, ...)` signature whose
# first line is missing; several body statements (primer-is-None
# handling, the `else:` of the sampling branch, writing t_next back into
# the sequence, and the final `return`) are also absent. Kept verbatim.
nb_samples, nb_tokens_to_generate, primer = None,
device = torch.device('cpu')

# Buffer for the generated token ids, one row per sample.
results = torch.zeros(
    nb_samples, nb_tokens_to_generate,
    dtype = torch.int64, device = device
# (closing of the torch.zeros call is not visible in this chunk)

# When a primer is given, generation starts after it.
first = primer.size(1)
results = torch.cat((primer, results), 1)

for input in results.split(batch_size):
    # Generate one token per step, left to right, past the primer.
    for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
        output = model(input)
        # Unnormalized distribution over the vocabulary at position s.
        logits = output[:, s]
        if args.synthesis_sampling:
            # Sample the next token from softmax(logits).
            dist = torch.distributions.categorical.Categorical(logits = logits)
            t_next = dist.sample()
        # Greedy fallback (the `else:` line is missing from this chunk).
        t_next = logits.argmax(1)
141 ######################################################################
# NOTE(review): incomplete fragment. These look like the abstract
# interface of a `Task` base class whose `class Task:` header and method
# bodies are not visible in this chunk. Each concrete task below
# implements all three methods.

# Yield training or test batches of token-id tensors.
def batches(self, split = 'train'):

# Number of distinct tokens the task's tokenizer produces.
def vocabulary_size(self):

# Generate samples / metrics with the current model at epoch n_epoch.
def produce_results(self, n_epoch, model):
153 ######################################################################
class TaskPicoCLVR(Task):

    # Convert a list of equal-length token sequences (lists of token
    # strings) into a LongTensor of token ids on the task's device.
    def descr2tensor(self, descr):
        t = [ [ self.token2id[u] for u in s ] for s in descr ]
        return torch.tensor(t, device = self.device)

    def __init__(self, batch_size,
                 height, width, nb_colors = 5,
                 device = torch.device('cpu')):

        # Generate nb scene descriptions and right-pad them with '<unk>'
        # so they all have the same token length.
        def generate_descr(nb):
            descr = picoclvr.generate(
                # NOTE(review): the first argument of the call
                # (presumably nb) is on a line not visible in this chunk.
                height = self.height, width = self.width,
                nb_colors = nb_colors
            # (closing of the call is not visible in this chunk)

            descr = [ s.strip().split(' ') for s in descr ]
            l = max([ len(s) for s in descr ])
            #descr = [ [ '<unk>' ] * (l - len(s)) + s for s in descr ]
            descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
            # (the return statement is not visible in this chunk)

        # NOTE(review): assignments of self.height, self.width and
        # self.device are on lines not visible in this chunk but are
        # used throughout the class.
        self.batch_size = batch_size

        # 250k samples by default, split 80% / 20% train / test below.
        nb = args.data_size if args.data_size > 0 else 250000

        self.train_descr = generate_descr((nb * 4) // 5)
        self.test_descr = generate_descr((nb * 1) // 5)

        # Build the tokenizer
        # NOTE(review): the initialization of `tokens` (a set) and the
        # loop header binding `s` are not visible in this chunk.
        for d in [ self.train_descr, self.test_descr ]:
            for t in s: tokens.add(t)
        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

        # Tokenize the train and test sets
        # NOTE(review): bare descr2tensor here would be a NameError at
        # run time -- presumably self.descr2tensor or a lost local
        # alias; confirm in the full file.
        self.train_input = descr2tensor(self.train_descr)
        self.test_input = descr2tensor(self.test_descr)

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        input = self.train_input if split == 'train' else self.test_input
        for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
            # (the yield statement is not visible in this chunk)

    def vocabulary_size(self):
        return len(self.token2id)

    # Autoregressively extend primer_descr by nb_tokens tokens and
    # return the whole sequence as a space-joined token string.
    def generate(self, primer_descr, model, nb_tokens):
        results = autoregression(
            model, self.batch_size,
            nb_samples = 1, nb_tokens = nb_tokens, primer = descr2tensor(primer_descr),
        # (remaining arguments / closing of the call not visible here)
        return ' '.join([ self.id2token[t.item()] for t in results.flatten() ])

    def produce_results(self, n_epoch, model):
        # Number of tokens to generate per sample; presumably one token
        # per grid cell plus separators -- TODO confirm the +3.
        nb_tokens = self.height * self.width + 3

        # NOTE(review): result_descr and nb_per_primer are defined on
        # lines not visible in this chunk.
        for primer_descr in [
                'red above green <sep> green top <sep> blue right of red <img>',
                'there is red <sep> there is yellow <sep> there is blue <img>',
                'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
                'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
        # (closing of the list / loop body opening not fully visible)
            for k in range(nb_per_primer):
                result_descr.append(self.generate(primer_descr, model, nb_tokens))

        # Render every generated description and save them as one grid,
        # one row per primer.
        img = [ picoclvr.descr2img(d, height = self.height, width = self.width)
                for d in result_descr ]
        img = torch.cat(img, 0)
        image_name = f'result_picoclvr_{n_epoch:04d}.png'
        torchvision.utils.save_image(
            # (the image tensor argument is not visible in this chunk)
            image_name, nrow = nb_per_primer, pad_value = 0.8
        # (closing of the call not visible here)
        log_string(f'wrote {image_name}')

        # Quantify how well generations satisfy their primers by
        # counting requested vs. missing properties.
        np = picoclvr.nb_properties(
            # (the descriptions argument is not visible in this chunk)
            height = self.height, width = self.width
        # (closing of the call not visible here)

        nb_requested_properties, _, nb_missing_properties = zip(*np)

        log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
253 ######################################################################
class TaskWiki103(Task):

    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.len_min = len_min    # shortest line (in tokens) kept
        self.len_max = len_max    # longest line kept
        self.min_freq = min_freq  # words rarer than this map to '<unk>'
        # NOTE(review): self.device is used below but its assignment is
        # on a line not visible in this chunk.

        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

        # Optionally restrict the vocabulary-building pass to a prefix
        # of the training stream.
        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

        # NOTE(review): the enclosing generator definition (presumably
        # `def yield_tokens():`) is not visible in this chunk.
        for l in tqdm.tqdm(train_iter, desc = 'vocab'):
            yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            # (the iterator argument is not visible in this chunk)
            specials = [ '<unk>', '<non>' ],
            min_freq = self.min_freq
        # (closing of the call not visible in this chunk)

        # Out-of-vocabulary words map to '<unk>'.
        self.vocab.set_default_index(self.vocab[ '<unk>' ])

    # Pad a batch of token lists with '<non>' to a common length and
    # convert it to a tensor of token ids.
    def tensorize(self, s):
        a = max(len(x) for x in s)
        return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])

    # Group the tokenized lines of ds into fixed-size batches, keeping
    # only the lines whose length is in [len_min, len_max].
    def yield_batches(self, ds):
        # NOTE(review): the accumulator initialization (s = []) and the
        # `for l in ds:` loop header are not visible in this chunk.
        q = self.tokenizer(l)
        if len(q) >= self.len_min and len(q) <= self.len_max:
            # (appending q to the accumulator is not visible here)
            if len(s) == self.batch_size:
                yield self.tensorize(s)
                # (resetting the accumulator is not visible here)

        # Flush the last, possibly partial, batch.
        yield self.tensorize(s)

    def batches(self, split = 'train'):
        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

        if args.data_size > 0:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))

    def vocabulary_size(self):
        return len(self.vocab)

    # Complete a set of fixed primers with the model and write the
    # generations to a text file.
    def produce_results(self, n_epoch, model):
        # NOTE(review): nb_tokens is defined on a line not visible here.
        file_name = f'result_wiki103_{n_epoch:04d}.txt'

        with open(file_name, 'w') as outfile:
            # NOTE(review): the `for primer in [` header enclosing these
            # string literals is not visible in this chunk.
            'the cat is hunting a',
            'paris is the capital',
            'cars are convenient',
            'the difference between men and women is',
            'the object was blue all over and green all over it was',
            'cherries are red and lemons are',
            'cherries are sweet and lemons are',
            'two plus three equals',

            t_primer = self.tokenizer(primer)
            # NOTE(review): t_generated initialization not visible here.

            for j in range(nb_tokens):
                # Re-tokenize primer + generation so far, and append one
                # slot for the token to predict.
                input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                output = model(input)
                logits = output[0, -1]
                if args.synthesis_sampling:
                    # Sample the next token from softmax(logits).
                    dist = torch.distributions.categorical.Categorical(logits = logits)
                    t_next = dist.sample()
                # Greedy fallback (`else:` line missing from this chunk).
                t_next = logits.argmax()
                t_generated.append(self.vocab.lookup_token(t_next))
                # Stop as soon as the padding token is produced.
                if t_generated[-1] == '<non>': break

            s = ' '.join(t_generated)

            outfile.write(f'<{primer}> {s}\n')

        log_string(f'wrote {file_name}')
353 ######################################################################
class TaskMNIST(Task):

    def __init__(self, batch_size, device = torch.device('cpu')):
        # NOTE(review): self.device is used in produce_results but its
        # assignment is not visible in this chunk.
        self.batch_size = batch_size

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),
        # (download argument / closing of the call not visible here)

        # Flatten each 28x28 image into a sequence of 784 pixel tokens.
        data_input = data_set.data.view(-1, 28 * 28).long()
        if args.data_size >= 0:
            data_input = data_input[:args.data_size]
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
            # (the yield statement is not visible in this chunk)

    def vocabulary_size(self):
        # (the return value -- presumably 256, one token per possible
        # pixel value -- is on a line not visible in this chunk)

    def produce_results(self, n_epoch, model):
        # NOTE(review): nb_samples is defined on a line not visible here.
        results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
        image_name = f'result_mnist_{n_epoch:04d}.png'
        # Invert (1 - x/255) so digits render dark-on-light.
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')
384 ######################################################################
log_string(f'device {device}')

# Pick the task; each choice also sets the default epoch count used
# when --nb_epochs is left at -1.
if args.data == 'wiki103':
    nb_epochs_default = 10
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    nb_epochs_default = 25
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    nb_epochs_default = 10
    task = TaskPicoCLVR(batch_size = args.batch_size,
                        height = args.picoclvr_height,
                        width = args.picoclvr_width,
                        nb_colors = args.picoclvr_nb_colors,
# NOTE(review): the closing `device = device)` argument and the `else:`
# introducing the line below are not visible in this chunk.
    raise ValueError(f'Unknown dataset {args.data}.')
vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')

##############################

# NOTE(review): the opening of the model construction (presumably
# `model = mygpt.MyGPT(`) and its closing / .to(device) call are on
# lines not visible in this chunk; these are its keyword arguments.
    vocabulary_size = vocabulary_size,
    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout

# Total parameter count, logged raw and in millions.
nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
421 ######################################################################
# Instantiate the optimizer selected on the command line. All choices
# share the same learning rate; any other name is an error.
_optimizers = {
    'sgd':   torch.optim.SGD,
    'adam':  torch.optim.Adam,
    'adamw': torch.optim.AdamW,
}

if args.optim not in _optimizers:
    raise ValueError(f'Unknown optimizer {args.optim}.')

optimizer = _optimizers[args.optim](model.parameters(), lr = args.learning_rate)
432 ######################################################################
# Number of epochs already completed, recovered from a checkpoint when
# one is loaded below.
nb_epochs_finished = 0

if args.no_checkpoint:
    log_string(f'not trying to load checkpoint.')
# NOTE(review): the `else:` and `try:` lines introducing the block
# below are not visible in this chunk.
    checkpoint = torch.load(args.checkpoint_name, map_location = device)
    nb_epochs_finished = checkpoint['nb_epochs_finished']
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')

except FileNotFoundError:
    # First run: no checkpoint on disk yet.
    log_string('starting from scratch.')

# NOTE(review): the second `except ...:` line is missing from this
# chunk; this is its handler body.
    log_string('error when loading the checkpoint.')
454 ######################################################################
# Use the dataset-specific default epoch count unless one was given.
nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default

# Empirical unigram statistics of the training set; their entropy gives
# a reference perplexity for the model.
# NOTE(review): the initialization of `token_count` is on a line not
# visible in this chunk.
for input in task.batches(split = 'train'):
    token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
# xlogy handles zero-probability tokens (0 * log 0 == 0).
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
#log_string(f'train set perplexity {train_set_perplexity}')
# Main loop: resume at nb_epochs_finished when a checkpoint was loaded.
for k in range(nb_epochs_finished, nb_epochs):

    # ---- training pass ----
    # NOTE(review): model.train() and the loss.backward() /
    # optimizer.step() lines are not visible in this chunk.
    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        # Next-token prediction: cross_entropy expects (N, C, T), hence
        # the transpose of the (N, T, C) model output.
        loss = F.cross_entropy(output.transpose(1, 2), input)
        # Weight the running loss by batch size for a correct mean.
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()

    # ---- evaluation pass (no gradients) ----
    with torch.autograd.no_grad():
        # (model.eval() is presumably here; not visible in this chunk)

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        # Clamp the exponent at 100 to avoid float overflow on
        # diverging runs.
        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

        log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')

        task.produce_results(k, model)

        # Persist enough state to resume training after an interruption.
        # NOTE(review): the `checkpoint = {` opening line and closing
        # brace are not visible in this chunk.
        'nb_epochs_finished': k + 1,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict()

        torch.save(checkpoint, args.checkpoint_name)
511 ######################################################################