# Written by Francois Fleuret <francois@fleuret.org>
-import math, sys, argparse, time, tqdm, itertools
+import math
+
+import torch
-import torch, torchtext, torchvision
from torch import nn
from torch.nn import functional as F
-######################################################################
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-######################################################################
-
-parser = argparse.ArgumentParser(description = 'My own GPT.')
-
-parser.add_argument('--log_filename',
- type = str, default = 'train.log')
-
-parser.add_argument('--download',
- type = bool, default = False)
-
-parser.add_argument('--seed',
- type = int, default = 0)
-
-parser.add_argument('--nb_epochs',
- type = int, default = 100)
-
-parser.add_argument('--batch_size',
- type = int, default = 25)
-
-parser.add_argument('--data',
- type = str, default = 'wiki103')
-
-parser.add_argument('--data_size',
- type = int, default = -1)
-
-parser.add_argument('--optim',
- type = str, default = 'adam')
-
-parser.add_argument('--learning_rate',
- type = float, default = 1e-4)
-
-parser.add_argument('--dim_model',
- type = int, default = 512)
-
-parser.add_argument('--dim_keys',
- type = int, default = 64)
-
-parser.add_argument('--dim_hidden',
- type = int, default = 2048)
-
-parser.add_argument('--nb_heads',
- type = int, default = 8)
-
-parser.add_argument('--nb_blocks',
- type = int, default = 12)
-
-parser.add_argument('--dropout',
- type = float, default = 0.1)
-
-parser.add_argument('--synthesis_sampling',
- type = bool, default = True)
-
-######################################################################
-
-args = parser.parse_args()
-
-log_file = open(args.log_filename, 'w')
-
-if args.seed >= 0:
- torch.manual_seed(args.seed)
-
-######################################################################
-
-def log_string(s):
- t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
-
- if log_file is not None:
- log_file.write(t + s + '\n')
- log_file.flush()
-
- print(t + s)
- sys.stdout.flush()
-
-for n in vars(args):
- log_string(f'args.{n} {getattr(args, n)}')
-
##############################
class Residual(nn.Module):
def randw(*d):
return nn.Parameter(torch.empty(*d).normal_(0, 1 / math.sqrt(d[-1])))
- self.wq = randw(nb_heads, dim_qk, dim_in)
- self.wk = randw(nb_heads, dim_qk, dim_in)
- self.wv = randw(nb_heads, dim_v, dim_in)
+ self.w_q = randw(nb_heads, dim_qk, dim_in)
+ self.w_k = randw(nb_heads, dim_qk, dim_in)
+ self.w_v = randw(nb_heads, dim_v, dim_in)
self.causal = causal
self.attention_dropout = attention_dropout
def forward(self, x):
- q = torch.einsum('ntc,hdc->nhtd', x, self.wq)
- k = torch.einsum('ntc,hdc->nhtd', x, self.wk)
- v = torch.einsum('ntc,hdc->nhtd', x, self.wv)
+ q = torch.einsum('ntc,hdc->nhtd', x, self.w_q)
+ k = torch.einsum('ntc,hdc->nhtd', x, self.w_k)
+ v = torch.einsum('ntc,hdc->nhtd', x, self.w_v)
r = math.sqrt(q.size(3))
a = torch.einsum('nhtd,nhsd->nhts', q, k).div(r)
if self.causal:
######################################################################
-class Task:
- def batches(self, split = 'train'):
- pass
-
- def vocabulary_size(self):
- pass
-
- def produce_results(self, n_epoch, model, nb_tokens = 50):
- pass
-
-######################################################################
-
-import picoclvr
-
-class TaskPicoCLVR(Task):
-
- def __init__(self, batch_size, height = 6, width = 8, device = torch.device('cpu')):
- self.batch_size = batch_size
- self.device = device
- nb = args.data_size if args.data_size > 0 else 250000
-
- descr = picoclvr.generate(nb, height = height, width = width)
- descr = [ s.strip().split(' ') for s in descr ]
- l = max([ len(s) for s in descr ])
- descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
-
- tokens = set()
- for s in descr:
- for t in s: tokens.add(t)
- self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
- self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
-
- t = [ [ self.token2id[u] for u in s ] for s in descr ]
- data_input = torch.tensor(t, device = self.device)
-
- self.test_input = data_input[:nb // 5]
- self.train_input = data_input[nb // 5:]
-
- def batches(self, split = 'train'):
- assert split in { 'train', 'test' }
- if split == 'train':
- for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = 'epoch'):
- yield batch
- else:
- for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = 'epoch'):
- yield batch
-
- def vocabulary_size(self):
- return len(self.token2id)
-
- def produce_results(self, n_epoch, model, nb_tokens = 50):
- img = [ ]
- nb_per_primer = 8
- for primer in [
- 'red above green <sep> green top <sep> blue right of red <img>',
- 'there is red <sep> there is yellow <sep> there is blue <img>',
- 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
- 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
- ]:
-
- for k in range(nb_per_primer):
- t_primer = primer.strip().split(' ')
- t_generated = [ ]
-
- for j in range(nb_tokens):
- t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
- input = torch.tensor(t, device = self.device)
- output = model(input)
- logits = output[0, -1]
- if args.synthesis_sampling:
- dist = torch.distributions.categorical.Categorical(logits = logits)
- t = dist.sample()
- else:
- t = logits.argmax()
- t_generated.append(self.id2token[t.item()])
-
- descr = [ ' '.join(t_primer + t_generated) ]
- img += [ picoclvr.descr2img(descr) ]
-
- img = torch.cat(img, 0)
- file_name = f'result_picoclvr_{n_epoch:04d}.png'
- torchvision.utils.save_image(img / 255.,
- file_name, nrow = nb_per_primer, pad_value = 0.8)
- log_string(f'wrote {file_name}')
-
-######################################################################
-
-class TaskWiki103(Task):
-
- def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
- device = torch.device('cpu')):
-
- self.batch_size = batch_size
- self.len_min = len_min
- self.len_max = len_max
- self.min_freq = min_freq
- self.device = device
-
- self.tokenizer = torchtext.data.get_tokenizer('basic_english')
- train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
-
- # Mostly for debug
- if args.data_size > 0:
- train_iter = itertools.islice(train_iter, args.data_size)
-
- def yield_tokens():
- for l in tqdm.tqdm(train_iter, desc = 'vocab'):
- yield self.tokenizer(l)
-
- self.vocab = torchtext.vocab.build_vocab_from_iterator(
- yield_tokens(),
- specials = [ '<unk>', '<non>' ],
- min_freq = self.min_freq
- )
-
- self.vocab.set_default_index(self.vocab[ '<unk>' ])
-
- def tensorize(self, s):
- a = max(len(x) for x in s)
- return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
-
- def yield_batches(self, ds):
- s = [ ]
- for l in ds:
- q = self.tokenizer(l)
- if len(q) >= self.len_min and len(q) <= self.len_max:
- s += [ q ]
- if len(s) == self.batch_size:
- yield self.tensorize(s)
- s = [ ]
-
- if len(s) > 0:
- yield self.tensorize(s)
-
- def batches(self, split = 'train'):
- data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
-
- # Mostly for debug
- if args.data_size > 0:
- data_iter = itertools.islice(data_iter, args.data_size)
-
- return self.yield_batches(tqdm.tqdm(data_iter, desc = 'epoch'))
-
- def vocabulary_size(self):
- return len(self.vocab)
-
- def produce_results(self, n_epoch, model, nb_tokens = 50):
- file_name = f'result_wiki103_{n_epoch:04d}.txt'
-
- with open(file_name, 'w') as outfile:
- for primer in [
- 'the cat is hunting a',
- 'paris is the capital',
- 'cars are convenient',
- 'the difference between men and women is',
- 'the object was blue all over and green all over it was',
- 'cherries are red and lemons are',
- 'cherries are sweet and lemons are',
- 'two plus three equals',
- 'deep learning is',
- ]:
- t_primer = self.tokenizer(primer)
- t_generated = [ ]
-
- for j in range(nb_tokens):
-
- input = self.tensorize([ t_primer + t_generated ]).to(self.device)
- output = model(input)
- logits = output[0, -1]
- if args.synthesis_sampling:
- dist = torch.distributions.categorical.Categorical(logits = logits)
- t = dist.sample()
- else:
- t = logits.argmax()
- t_generated.append(self.vocab.lookup_token(t))
- if t_generated[-1] == '<non>': break
-
- s = ' '.join(t_generated)
-
- outfile.write(f'<{primer}> {s}\n')
-
- log_string(f'wrote {file_name}')
-
-######################################################################
-
-class TaskMNIST(Task):
-
- def __init__(self, batch_size, device = torch.device('cpu')):
- self.device = device
- self.batch_size = batch_size
-
- def batches(self, split = 'train'):
- assert split in { 'train', 'test' }
- data_set = torchvision.datasets.MNIST(
- root = './data', train = (split == 'train'),
- download = True
- )
- data_input = data_set.data.view(-1, 28 * 28).long()
- if args.data_size >= 0:
- data_input = data_input[:args.data_size]
- for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = 'epoch'):
- yield batch
-
- def vocabulary_size(self):
- return 256
-
- def produce_results(self, n_epoch, model, nb_samples = 64):
- results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
- for input in results.split(self.batch_size):
- for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
- output = model(input)
- logits = output[:, s]
- if args.synthesis_sampling:
- dist = torch.distributions.categorical.Categorical(logits = logits)
- t = dist.sample()
- else:
- t = logits.argmax(1)
- input[:, s + 1] = t
-
- image_name = f'result_mnist_{n_epoch:04d}.png'
- torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
- image_name, nrow = 16, pad_value = 0.8)
- log_string(f'wrote {image_name}')
-
-######################################################################
-
-def check_causality(model):
- #m = model[1:]
- input = torch.rand(1, 5, dim_model).requires_grad_()
- output = m(input)
- a = torch.zeros(output.size(1), input.size(1))
- for k in range(output.size(1)):
- for d in range(output.size(2)):
- g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
- a[k] += g.squeeze(0).pow(2).sum(1)
- print(a)
-
-######################################################################
-
-log_string(f'device {device}')
-
-if args.data == 'wiki103':
- task = TaskWiki103(batch_size = args.batch_size, device = device)
-elif args.data == 'mnist':
- task = TaskMNIST(batch_size = args.batch_size, device = device)
-elif args.data == 'picoclvr':
- task = TaskPicoCLVR(batch_size = args.batch_size, device = device)
-else:
- raise ValueError(f'Unknown dataset {args.data}.')
-
-vocabulary_size = task.vocabulary_size()
-
-log_string(f'vocabulary_size {vocabulary_size}')
-
-##############################
-
-model = MyGPT(
- vocabulary_size = vocabulary_size,
- dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
- nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
-)
-
-nb_parameters = sum(p.numel() for p in model.parameters())
-log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
-
-model.to(device)
-
-######################################################################
-
-if args.optim == 'sgd':
- optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adam':
- optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adamw':
- optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
-else:
- raise ValueError(f'Unknown optimizer {args.optim}.')
-
-for k in range(args.nb_epochs):
-
- model.train()
-
- nb_train_samples, acc_train_loss = 0, 0.0
-
- for input in task.batches(split = 'train'):
- input = input.to(device)
- output = model(input)
- loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
- acc_train_loss += loss.item() * input.size(0)
- nb_train_samples += input.size(0)
-
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- with torch.autograd.no_grad():
-
- model.eval()
-
- nb_test_samples, acc_test_loss = 0, 0.0
-
- for input in task.batches(split = 'test'):
- input = input.to(device)
- output = model(input)
- loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
- acc_test_loss += loss.item() * input.size(0)
- nb_test_samples += input.size(0)
+if __name__ == '__main__':
+ vocabulary_size = 10
+ x = torch.randint(vocabulary_size, (25, 100))
- log_string(f'perplexity {k+1} train {math.exp(min(100, acc_train_loss/nb_train_samples))} test {math.exp(min(100, acc_test_loss/nb_test_samples))}')
+ model = MyGPT(
+ vocabulary_size = vocabulary_size,
+ dim_model = 16, dim_keys = 50, dim_hidden = 100,
+ nb_heads = 2, nb_blocks = 3,
+ dropout = 0.1
+ )
- task.produce_results(k, model)
+ y = model(x)
######################################################################