3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, argparse, time, tqdm, itertools
10 import torch, torchtext, torchvision
12 from torch.nn import functional as F
14 ######################################################################
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
18 ######################################################################
def _parse_bool(text):
    """Convert a command-line string into a boolean.

    argparse calls ``type`` on the raw string, and ``bool('False')`` is True
    because every non-empty string is truthy. With ``type = bool`` a boolean
    option could therefore never be switched off from the command line.

    Accepts true/t/yes/y/1/on and false/f/no/n/0/off (case-insensitive); the
    empty string maps to False, as it did with ``bool``.

    Raises argparse.ArgumentTypeError for any other value.
    """
    lowered = text.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {text!r}')


parser = argparse.ArgumentParser(description = 'My own GPT.')

parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--download',
                    type = _parse_bool, default = False)

parser.add_argument('--seed',
                    type = int, default = 0)

parser.add_argument('--nb_epochs',
                    type = int, default = 100)

parser.add_argument('--batch_size',
                    type = int, default = 25)

parser.add_argument('--data',
                    type = str, default = 'wiki103')

parser.add_argument('--data_size',
                    type = int, default = -1)

parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

parser.add_argument('--synthesis_sampling',
                    type = _parse_bool, default = True)
70 ######################################################################
# Parse the command line once; the rest of the script reads its options from `args`.
args = parser.parse_args()

# Open the log file for writing ('w' truncates an existing file of that name).
log_file = open(args.log_filename, 'w')

# Seed PyTorch's random number generator with --seed.
torch.manual_seed(args.seed)
79 ######################################################################
82 t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
84 if log_file is not None:
85 log_file.write(t + s + '\n')
92 log_string(f'args.{n} {getattr(args, n)}')
94 ##############################
96 class Residual(nn.Module):
97 def __init__(self, *f):
99 self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
101 def forward(self, x):
104 ##############################
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position codes to a batch of embeddings.

    From Vaswani et al. 2017, with L = len_max and D the feature size:

        PE_{t,2i}   = sin(t/(L^{2i/D}))
        PE_{t,2i+1} = cos(t/(L^{2i/D}))
    """

    def __init__(self, len_max):
        # nn.Module has to be initialised before the instance can be called.
        super().__init__()
        self.len_max = len_max

    def forward(self, x):
        """Return x plus the position codes.

        x is indexed as [sample, position, feature]; the codes depend only
        on the last two indices and are broadcast over the first one.
        """
        t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
        j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
        # k is 0 on even feature indices and 1 on odd ones, so j - k equals
        # 2i in the formula above, and the pi/2 phase shift turns the sine
        # into a cosine on the odd indices.
        k = j % 2
        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)
        return x + pe[None, :, :]
120 ##############################
class QKVAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Each head has its own query, key and value projections of the input.
    The per-head results are concatenated on the last dimension, so the
    output has nb_heads * dim_v features per position.

    With causal = True, position t only attends to positions s <= t.
    attention_dropout is applied to the attention weights in training mode.
    """

    def __init__(self, dim_in, dim_qk, dim_v, nb_heads = 1, causal = False, attention_dropout = 0.0):
        # nn.Module has to be initialised before parameters can be registered.
        super().__init__()

        def randw(*d):
            # Gaussian weights with standard deviation 1 / sqrt(last dimension).
            return nn.Parameter(torch.empty(*d).normal_(0, 1 / math.sqrt(d[-1])))

        self.wq = randw(nb_heads, dim_qk, dim_in)
        self.wk = randw(nb_heads, dim_qk, dim_in)
        self.wv = randw(nb_heads, dim_v, dim_in)
        # Kept so that forward() masks future positions only when asked to.
        self.causal = causal
        self.attention_dropout = attention_dropout

    def forward(self, x):
        """Map x, indexed [n, t, c], to a tensor indexed [n, t, nb_heads * dim_v]."""
        q = torch.einsum('ntc,hdc->nhtd', x, self.wq)
        k = torch.einsum('ntc,hdc->nhtd', x, self.wk)
        v = torch.einsum('ntc,hdc->nhtd', x, self.wv)
        # Scale the scores by the square root of the query/key dimension.
        r = math.sqrt(q.size(3))
        a = torch.einsum('nhtd,nhsd->nhts', q, k).div(r)
        if self.causal:
            # True strictly above the diagonal, i.e. where s > t; those
            # scores become -inf and get zero weight after the softmax.
            mask = torch.tril(q.new_ones(a.size(2), a.size(3)))[None, None, :, :] == 0
            a = a.masked_fill(mask, float('-inf'))
        a = a.softmax(dim = 3)
        a = F.dropout(a, self.attention_dropout, self.training)
        y = torch.einsum('nhts,nhsd->nhtd', a, v)
        return y.permute(0, 2, 1, 3).flatten(2) # nhtd -> nt(hd)
149 ##############################
151 class MyGPT(nn.Module):
154 dim_model, dim_keys, dim_hidden,
155 nb_heads, nb_blocks, dropout = 0.):
159 assert dim_model % nb_heads == 0
161 self.embedding = nn.Sequential(
162 nn.Embedding(vocabulary_size, dim_model),
164 PositionalEncoding(len_max = 1e5),
169 for _ in range(nb_blocks):
172 nn.LayerNorm(dim_model),
175 dim_qk = dim_keys, dim_v = dim_model // nb_heads,
177 causal = True, attention_dropout = dropout
179 nn.Linear(in_features = dim_model, out_features = dim_model),
182 nn.LayerNorm(dim_model),
183 nn.Linear(in_features = dim_model, out_features = dim_hidden),
185 nn.Linear(in_features = dim_hidden, out_features = dim_model),
190 self.trunk = nn.Sequential(*trunk_blocks)
192 self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)
194 def forward(self, x):
195 x = self.embedding(x)
200 ######################################################################
203 def batches(self, split = 'train'):
206 def vocabulary_size(self):
209 def produce_results(self, n_epoch, model, nb_tokens = 50):
212 ######################################################################
216 class TaskPicoCLVR(Task):
218 def __init__(self, batch_size, height = 6, width = 8, device = torch.device('cpu')):
219 self.batch_size = batch_size
221 nb = args.data_size if args.data_size > 0 else 250000
223 descr = picoclvr.generate(nb, height = height, width = width)
224 descr = [ s.strip().split(' ') for s in descr ]
225 l = max([ len(s) for s in descr ])
226 descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
230 for t in s: tokens.add(t)
231 self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
232 self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
234 t = [ [ self.token2id[u] for u in s ] for s in descr ]
235 data_input = torch.tensor(t, device = self.device)
237 self.test_input = data_input[:nb // 5]
238 self.train_input = data_input[nb // 5:]
240 def batches(self, split = 'train'):
241 assert split in { 'train', 'test' }
243 for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = 'epoch'):
246 for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = 'epoch'):
249 def vocabulary_size(self):
250 return len(self.token2id)
252 def produce_results(self, n_epoch, model, nb_tokens = 50):
256 'red above green <sep> green top <sep> blue right of red <img>',
257 'there is red <sep> there is yellow <sep> there is blue <img>',
258 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
259 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
262 for k in range(nb_per_primer):
263 t_primer = primer.strip().split(' ')
266 for j in range(nb_tokens):
267 t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
268 input = torch.tensor(t, device = self.device)
269 output = model(input)
270 logits = output[0, -1]
271 if args.synthesis_sampling:
272 dist = torch.distributions.categorical.Categorical(logits = logits)
276 t_generated.append(self.id2token[t.item()])
278 descr = [ ' '.join(t_primer + t_generated) ]
279 img += [ picoclvr.descr2img(descr) ]
281 img = torch.cat(img, 0)
282 file_name = f'result_picoclvr_{n_epoch:04d}.png'
283 torchvision.utils.save_image(img / 255.,
284 file_name, nrow = nb_per_primer, pad_value = 0.8)
285 log_string(f'wrote {file_name}')
287 ######################################################################
289 class TaskWiki103(Task):
291 def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
292 device = torch.device('cpu')):
294 self.batch_size = batch_size
295 self.len_min = len_min
296 self.len_max = len_max
297 self.min_freq = min_freq
300 self.tokenizer = torchtext.data.get_tokenizer('basic_english')
301 train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
304 if args.data_size > 0:
305 train_iter = itertools.islice(train_iter, args.data_size)
308 for l in tqdm.tqdm(train_iter, desc = 'vocab'):
309 yield self.tokenizer(l)
311 self.vocab = torchtext.vocab.build_vocab_from_iterator(
313 specials = [ '<unk>', '<non>' ],
314 min_freq = self.min_freq
317 self.vocab.set_default_index(self.vocab[ '<unk>' ])
319 def tensorize(self, s):
320 a = max(len(x) for x in s)
321 return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
323 def yield_batches(self, ds):
326 q = self.tokenizer(l)
327 if len(q) >= self.len_min and len(q) <= self.len_max:
329 if len(s) == self.batch_size:
330 yield self.tensorize(s)
334 yield self.tensorize(s)
def batches(self, split = 'train'):
    """Return the batch generator built by yield_batches for the given split."""
    lines = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

    # A positive --data_size caps how many items are taken from the dataset.
    if args.data_size > 0:
        lines = itertools.islice(lines, args.data_size)

    progress = tqdm.tqdm(lines, desc = 'epoch')
    return self.yield_batches(progress)
345 def vocabulary_size(self):
346 return len(self.vocab)
348 def produce_results(self, n_epoch, model, nb_tokens = 50):
349 file_name = f'result_wiki103_{n_epoch:04d}.txt'
351 with open(file_name, 'w') as outfile:
353 'the cat is hunting a',
354 'paris is the capital',
355 'cars are convenient',
356 'the difference between men and women is',
357 'the object was blue all over and green all over it was',
358 'cherries are red and lemons are',
359 'cherries are sweet and lemons are',
360 'two plus three equals',
363 t_primer = self.tokenizer(primer)
366 for j in range(nb_tokens):
368 input = self.tensorize([ t_primer + t_generated ]).to(self.device)
369 output = model(input)
370 logits = output[0, -1]
371 if args.synthesis_sampling:
372 dist = torch.distributions.categorical.Categorical(logits = logits)
376 t_generated.append(self.vocab.lookup_token(t))
377 if t_generated[-1] == '<non>': break
379 s = ' '.join(t_generated)
381 outfile.write(f'<{primer}> {s}\n')
383 log_string(f'wrote {file_name}')
385 ######################################################################
387 class TaskMNIST(Task):
389 def __init__(self, batch_size, device = torch.device('cpu')):
391 self.batch_size = batch_size
393 def batches(self, split = 'train'):
394 assert split in { 'train', 'test' }
395 data_set = torchvision.datasets.MNIST(
396 root = './data', train = (split == 'train'),
399 data_input = data_set.data.view(-1, 28 * 28).long()
400 if args.data_size >= 0:
401 data_input = data_input[:args.data_size]
402 for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = 'epoch'):
405 def vocabulary_size(self):
408 def produce_results(self, n_epoch, model, nb_samples = 64):
409 results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
410 for input in results.split(self.batch_size):
411 for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
412 output = model(input)
413 logits = output[:, s]
414 if args.synthesis_sampling:
415 dist = torch.distributions.categorical.Categorical(logits = logits)
421 image_name = f'result_mnist_{n_epoch:04d}.png'
422 torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
423 image_name, nrow = 16, pad_value = 0.8)
424 log_string(f'wrote {image_name}')
426 ######################################################################
428 def check_causality(model):
430 input = torch.rand(1, 5, dim_model).requires_grad_()
432 a = torch.zeros(output.size(1), input.size(1))
433 for k in range(output.size(1)):
434 for d in range(output.size(2)):
435 g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
436 a[k] += g.squeeze(0).pow(2).sum(1)
439 ######################################################################
log_string(f'device {device}')

# Build the task named by --data. The raise must sit in an else branch:
# placed directly after the chain it would fire for valid names as well.
if args.data == 'wiki103':
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    task = TaskPicoCLVR(batch_size = args.batch_size, device = device)
else:
    raise ValueError(f'Unknown dataset {args.data}.')

# The task's vocabulary size is logged here and reused below under this name.
vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')
456 ##############################
459 vocabulary_size = vocabulary_size,
460 dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
461 nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
464 nb_parameters = sum(p.numel() for p in model.parameters())
465 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
469 ######################################################################
# Build the optimizer named by --optim over all model parameters. The raise
# must sit in an else branch: placed directly after the chain it would fire
# for valid names as well.
if args.optim == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
else:
    raise ValueError(f'Unknown optimizer {args.optim}.')
480 for k in range(args.nb_epochs):
484 nb_train_samples, acc_train_loss = 0, 0.0
486 for input in task.batches(split = 'train'):
487 input = input.to(device)
488 output = model(input)
489 loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
490 acc_train_loss += loss.item() * input.size(0)
491 nb_train_samples += input.size(0)
493 optimizer.zero_grad()
497 with torch.autograd.no_grad():
501 nb_test_samples, acc_test_loss = 0, 0.0
503 for input in task.batches(split = 'test'):
504 input = input.to(device)
505 output = model(input)
506 loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
507 acc_test_loss += loss.item() * input.size(0)
508 nb_test_samples += input.size(0)
510 log_string(f'perplexity {k+1} train {math.exp(min(100, acc_train_loss/nb_train_samples))} test {math.exp(min(100, acc_test_loss/nb_test_samples))}')
512 task.produce_results(k, model)
514 ######################################################################