3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, argparse, time, tqdm, itertools
10 import torch, torchtext, torchvision
12 from torch.nn import functional as F
######################################################################

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

######################################################################
def _str2bool(v):
    """Parse a boolean command-line value.

    The original used `type = bool`, which is a classic argparse pitfall:
    bool('False') is True, so any explicit value on the command line
    (including `--download False`) parsed as True. This converter keeps
    the flag-with-value interface but parses the value correctly.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in { '1', 'true', 'yes', 'y' }:
        return True
    if v.lower() in { '0', 'false', 'no', 'n' }:
        return False
    raise argparse.ArgumentTypeError(f'boolean value expected, got {v!r}')

parser = argparse.ArgumentParser(description = 'My own GPT.')

parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--download',
                    type = _str2bool, default = False)

parser.add_argument('--seed',
                    type = int, default = 0)

parser.add_argument('--nb_epochs',
                    type = int, default = 100)

parser.add_argument('--batch_size',
                    type = int, default = 25)

parser.add_argument('--data',
                    type = str, default = 'wiki103')

# data_size <= 0 means "use the full data set".
parser.add_argument('--data_size',
                    type = int, default = -1)

parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

# When true, sample from the output distribution; otherwise take the argmax.
parser.add_argument('--synthesis_sampling',
                    type = _str2bool, default = True)
######################################################################

# Parse the command line into the global configuration object.
args = parser.parse_args()

# Global log file, kept open for the whole run (never explicitly closed;
# released implicitly at interpreter exit).
log_file = open(args.log_filename, 'w')

# Make runs reproducible for a given --seed.
torch.manual_seed(args.seed)
######################################################################

def log_string(s):
    """Write a timestamped line to the log file and echo it to stdout."""
    # NOTE(review): the `def` header and the tail of this function were
    # missing from the garbled source; flush/print lines reconstructed —
    # confirm against version-control history.
    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
    if log_file is not None:
        log_file.write(t + s + '\n')
        log_file.flush()
    print(t + s)
    sys.stdout.flush()

# Log the full configuration so every run is reproducible from its log.
for n in vars(args):
    log_string(f'args.{n} {getattr(args, n)}')
######################################################################

class Task:
    """Interface shared by the data-set tasks below.

    NOTE(review): the class header and the method bodies were missing from
    the garbled source; reconstructed as an abstract interface whose
    methods raise instead of silently returning None.
    """

    def batches(self, split = 'train'):
        """Yield input batches for the given split ('train' or 'test')."""
        raise NotImplementedError

    def vocabulary_size(self):
        """Return the number of distinct tokens of the task."""
        raise NotImplementedError

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        """Generate and persist sample completions for epoch `n_epoch`."""
        raise NotImplementedError

######################################################################
class TaskPicoCLVR(Task):
    """Colored-square scene-description task backed by the picoclvr generator.

    NOTE(review): this region of the file arrived garbled, with structural
    lines missing (unbalanced parentheses, loop headers without bodies).
    Lines marked `reconstructed` below are inferred from context — confirm
    against version-control history.
    """

    def __init__(self, batch_size,
                 height = 6, width = 8, many_colors = False,
                 device = torch.device('cpu')):
        self.batch_size = batch_size
        self.device = device  # NOTE(review): reconstructed; self.device is read below

        # Number of generated scene descriptions (train + test pool).
        nb = args.data_size if args.data_size > 0 else 250000

        descr = picoclvr.generate(
            nb,  # NOTE(review): reconstructed first argument
            height = height, width = width,
            many_colors = many_colors
        )

        # Tokenize and right-pad every description to the maximum length.
        descr = [ s.strip().split(' ') for s in descr ]
        l = max([ len(s) for s in descr ])
        descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]

        # Build the token <-> id vocabulary from every token seen.
        tokens = set()  # NOTE(review): reconstructed loop header below
        for s in descr:
            for t in s: tokens.add(t)
        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

        t = [ [ self.token2id[u] for u in s ] for s in descr ]
        data_input = torch.tensor(t, device = self.device)
        # First fifth is the test set, the rest the train set.
        self.test_input = data_input[:nb // 5]
        self.train_input = data_input[nb // 5:]

    def batches(self, split = 'train'):
        """Yield token-id batches of the requested split."""
        assert split in { 'train', 'test' }
        # NOTE(review): the if/else and `yield` lines were missing; reconstructed.
        if split == 'train':
            for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
                yield batch
        else:
            for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
                yield batch

    def vocabulary_size(self):
        """Number of distinct tokens seen while building the data set."""
        return len(self.token2id)

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        """Autoregressively complete a few fixed primers and save the
        rendered scenes as one PNG grid."""
        nb_per_primer = 8  # NOTE(review): reconstructed; used as the image-grid row length
        img = [ ]

        for primer in [
            'red above green <sep> green top <sep> blue right of red <img>',
            'there is red <sep> there is yellow <sep> there is blue <img>',
            'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
            'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
        ]:
            for k in range(nb_per_primer):
                t_primer = primer.strip().split(' ')
                t_generated = [ ]

                for j in range(nb_tokens):
                    t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
                    input = torch.tensor(t, device = self.device)
                    output = model(input)
                    logits = output[0, -1]
                    if args.synthesis_sampling:
                        dist = torch.distributions.categorical.Categorical(logits = logits)
                        t = dist.sample()
                    else:
                        t = logits.argmax()  # NOTE(review): reconstructed greedy branch
                    t_generated.append(self.id2token[t.item()])

                descr = [ ' '.join(t_primer + t_generated) ]
                img += [ picoclvr.descr2img(descr) ]

        img = torch.cat(img, 0)
        file_name = f'result_picoclvr_{n_epoch:04d}.png'
        torchvision.utils.save_image(img / 255.,
                                     file_name, nrow = nb_per_primer, pad_value = 0.8)
        log_string(f'wrote {file_name}')
######################################################################

class TaskWiki103(Task):
    """Word-level language modeling on WikiText-103 via torchtext.

    NOTE(review): this region of the file arrived garbled, with structural
    lines missing. Lines marked `reconstructed` below are inferred from
    context — confirm against version-control history.
    """

    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):
        self.batch_size = batch_size
        self.len_min = len_min
        self.len_max = len_max
        self.min_freq = min_freq
        self.device = device  # NOTE(review): reconstructed; self.device is read below

        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

        # Mostly for debugging: cap the number of lines used for the vocabulary.
        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

        def yield_tokens():  # NOTE(review): reconstructed generator header
            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
                yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            yield_tokens(),  # NOTE(review): reconstructed first argument
            specials = [ '<unk>', '<non>' ],
            min_freq = self.min_freq
        )
        # Out-of-vocabulary words map to '<unk>'.
        self.vocab.set_default_index(self.vocab[ '<unk>' ])

    def tensorize(self, s):
        """Convert a list of token lists into one LongTensor, right-padding
        with '<non>' up to the longest sequence."""
        a = max(len(x) for x in s)
        return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])

    def yield_batches(self, ds):
        """Group tokenized lines of acceptable length into padded batches."""
        # NOTE(review): the accumulator scaffolding was missing; reconstructed.
        s = [ ]
        for l in ds:
            q = self.tokenizer(l)
            if len(q) >= self.len_min and len(q) <= self.len_max:
                s.append(q)
                if len(s) == self.batch_size:
                    yield self.tensorize(s)
                    s = [ ]

        # Flush the final, possibly smaller, batch.
        if len(s) > 0:
            yield self.tensorize(s)

    def batches(self, split = 'train'):
        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

        # Mostly for debugging: cap the number of lines per epoch.
        if args.data_size > 0:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))

    def vocabulary_size(self):
        return len(self.vocab)

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        """Complete a few fixed English primers and write them to a text file."""
        file_name = f'result_wiki103_{n_epoch:04d}.txt'

        with open(file_name, 'w') as outfile:
            for primer in [
                'the cat is hunting a',
                'paris is the capital',
                'cars are convenient',
                'the difference between men and women is',
                'the object was blue all over and green all over it was',
                'cherries are red and lemons are',
                'cherries are sweet and lemons are',
                'two plus three equals',
            ]:
                t_primer = self.tokenizer(primer)
                t_generated = [ ]

                for j in range(nb_tokens):
                    input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                    output = model(input)
                    logits = output[0, -1]
                    if args.synthesis_sampling:
                        dist = torch.distributions.categorical.Categorical(logits = logits)
                        t = dist.sample()
                    else:
                        t = logits.argmax()  # NOTE(review): reconstructed greedy branch
                    t_generated.append(self.vocab.lookup_token(t))
                    # Stop as soon as the padding token is produced.
                    if t_generated[-1] == '<non>': break

                s = ' '.join(t_generated)
                outfile.write(f'<{primer}> {s}\n')

        log_string(f'wrote {file_name}')
######################################################################

class TaskMNIST(Task):
    """Pixel-level autoregressive modeling of MNIST digits: each of the
    28*28 pixels is one token, its 8-bit intensity the token id.

    NOTE(review): this region of the file arrived garbled, with structural
    lines missing. Lines marked `reconstructed` below are inferred from
    context — confirm against version-control history.
    """

    def __init__(self, batch_size, device = torch.device('cpu')):
        self.batch_size = batch_size
        self.device = device  # NOTE(review): reconstructed; produce_results reads self.device

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),
            download = args.download  # NOTE(review): reconstructed keyword argument
        )
        # Flatten each 28x28 image into a 784-long sequence of pixel tokens.
        data_input = data_set.data.view(-1, 28 * 28).long()
        # NOTE(review): `>= 0` (not `> 0`) is how the source reads; data_size == 0
        # therefore yields an empty data set, unlike the other tasks — confirm intent.
        if args.data_size >= 0:
            data_input = data_input[:args.data_size]
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
            yield batch  # NOTE(review): reconstructed

    def vocabulary_size(self):
        # One token per 8-bit pixel intensity.
        return 256  # NOTE(review): reconstructed return value

    def produce_results(self, n_epoch, model, nb_samples = 64):
        """Sample digits pixel-by-pixel and save them as one PNG grid."""
        results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
        for input in results.split(self.batch_size):
            for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
                output = model(input)
                logits = output[:, s]
                if args.synthesis_sampling:
                    dist = torch.distributions.categorical.Categorical(logits = logits)
                    t = dist.sample()
                else:
                    t = logits.argmax(1)  # NOTE(review): reconstructed greedy branch
                # Write the sampled pixel into the next position; `input` is a
                # view into `results`, so this fills the final image.
                input[:, s + 1] = t  # NOTE(review): reconstructed write-back

        image_name = f'result_mnist_{n_epoch:04d}.png'
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')
######################################################################

def check_causality(model):
    """Empirically check that output position k depends only on input
    positions <= k, by accumulating squared input gradients into a
    (output-position x input-position) matrix.

    NOTE(review): the original referenced a bare `dim_model`, which is not
    defined at module scope (the rest of the file uses args.dim_model) —
    fixed. The line computing `output` and the final use of `a` were
    missing from the garbled source and are reconstructed.
    """
    input = torch.rand(1, 5, args.dim_model).requires_grad_()
    output = model(input)  # NOTE(review): reconstructed
    a = torch.zeros(output.size(1), input.size(1))
    for k in range(output.size(1)):
        for d in range(output.size(2)):
            # Gradient of one scalar output w.r.t. the whole input.
            g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
            a[k] += g.squeeze(0).pow(2).sum(1)
    # Causal models should produce a lower-triangular matrix.
    print(a)  # NOTE(review): reconstructed use of the dependency matrix
######################################################################

log_string(f'device {device}')

# Instantiate the task selected on the command line.
if args.data == 'wiki103':
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    task = TaskPicoCLVR(batch_size = args.batch_size, device = device)
else:
    # NOTE(review): the `else:` header was missing from the garbled source;
    # without it the raise would have been unconditional.
    raise ValueError(f'Unknown dataset {args.data}.')

vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')
##############################

# NOTE(review): the line constructing the model was missing from the garbled
# source; `mygpt.MyGPT` is inferred from the script's description ('My own
# GPT') — confirm the module/constructor name (and its import at the top of
# the file) against version-control history.
model = mygpt.MyGPT(
    vocabulary_size = vocabulary_size,
    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
)

model.to(device)  # NOTE(review): reconstructed; the training loop feeds `device` tensors

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
######################################################################

# Select the optimizer from the command line.
if args.optim == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
else:
    # NOTE(review): the `else:` header was missing from the garbled source;
    # without it the raise would have been unconditional.
    raise ValueError(f'Unknown optimizer {args.optim}.')
for k in range(args.nb_epochs):

    model.train()  # NOTE(review): reconstructed (missing from garbled source)

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        # Next-token prediction: logits at position t predict the token at t+1;
        # cross_entropy wants (N, C, ...) hence the transpose.
        loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        loss.backward()   # NOTE(review): reconstructed — the garbled source lost
        optimizer.step()  # the lines between zero_grad() and the eval block

    # NOTE(review): torch.autograd.no_grad() is a legacy alias of torch.no_grad().
    with torch.autograd.no_grad():

        model.eval()  # NOTE(review): reconstructed

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

    # Clamp the mean loss before exponentiating to avoid float overflow.
    train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
    test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

    log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')

    task.produce_results(k, model)

######################################################################