Initial commit
author    Francois Fleuret <francois@fleuret.org>
          Sun, 24 Apr 2022 08:18:51 +0000 (10:18 +0200)
committer Francois Fleuret <francois@fleuret.org>
          Sun, 24 Apr 2022 08:18:51 +0000 (10:18 +0200)
mygpt.py [new file with mode: 0755]
picoclvr.py [new file with mode: 0755]

diff --git a/mygpt.py b/mygpt.py
new file mode 100755 (executable)
index 0000000..970ee7b
--- /dev/null
+++ b/mygpt.py
@@ -0,0 +1,514 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, sys, argparse, time, tqdm, itertools
+
+import torch, torchtext, torchvision
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+######################################################################
+
+parser = argparse.ArgumentParser(description = 'My own GPT.')
+
+parser.add_argument('--log_filename',
+                    type = str, default = 'train.log')
+
+# argparse's type = bool treats any non-empty string (including 'False')
+# as True, so use a flag action instead
+parser.add_argument('--download',
+                    action = 'store_true', default = False)
+
+parser.add_argument('--seed',
+                    type = int, default = 0)
+
+parser.add_argument('--nb_epochs',
+                    type = int, default = 100)
+
+parser.add_argument('--batch_size',
+                    type = int, default = 25)
+
+parser.add_argument('--data',
+                    type = str, default = 'wiki103')
+
+parser.add_argument('--data_size',
+                    type = int, default = -1)
+
+parser.add_argument('--optim',
+                    type = str, default = 'adam')
+
+parser.add_argument('--learning_rate',
+                    type = float, default = 1e-4)
+
+parser.add_argument('--dim_model',
+                    type = int, default = 512)
+
+parser.add_argument('--dim_keys',
+                    type = int, default = 64)
+
+parser.add_argument('--dim_hidden',
+                    type = int, default = 2048)
+
+parser.add_argument('--nb_heads',
+                    type = int, default = 8)
+
+parser.add_argument('--nb_blocks',
+                    type = int, default = 12)
+
+parser.add_argument('--dropout',
+                    type = float, default = 0.1)
+
+# Same type = bool issue as --download; BooleanOptionalAction (Python >= 3.9)
+# keeps the default True and provides --no-synthesis_sampling to get the argmax
+parser.add_argument('--synthesis_sampling',
+                    action = argparse.BooleanOptionalAction, default = True)
+
+######################################################################
+
+args = parser.parse_args()
+
+log_file = open(args.log_filename, 'w')
+
+if args.seed >= 0:
+    torch.manual_seed(args.seed)
+
+######################################################################
+
+def log_string(s):
+    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
+
+    if log_file is not None:
+        log_file.write(t + s + '\n')
+        log_file.flush()
+
+    print(t + s)
+    sys.stdout.flush()
+
+for n in vars(args):
+    log_string(f'args.{n} {getattr(args, n)}')
+
+##############################
+
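+# Skip connection: computes x + f(x), where f is given either as a single
+# module or as a sequence of modules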
+class Residual(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, x):
+        return x + self.f(x)
+
+##############################
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, len_max):
+        super().__init__()
+        self.len_max = len_max
+
+    # From Vaswani et al 2017
+    # PE_{t,2i}   = sin(t/(L^{2i/D}))
+    # PE_{t,2i+1} = cos(t/(L^{2i/D}))
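+    # The two cases are merged below into a single sin by shifting the
+    # phase of the odd dimensions by pi/2, since sin(x + pi/2) = cos(x)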
+    def forward(self, x):
+        t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
+        j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
+        k = j%2
+        return x + torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)[None, :, :]
+
+##############################
+
+class QKVAttention(nn.Module):
+    def __init__(self, dim_in, dim_qk, dim_v, nb_heads = 1, causal = False, attention_dropout = 0.0):
+        super().__init__()
+
+        def randw(*d):
+            return nn.Parameter(torch.empty(*d).normal_(0, 1 / math.sqrt(d[-1])))
+
+        self.wq = randw(nb_heads, dim_qk, dim_in)
+        self.wk = randw(nb_heads, dim_qk, dim_in)
+        self.wv = randw(nb_heads, dim_v, dim_in)
+        self.causal = causal
+        self.attention_dropout = attention_dropout
+
+    def forward(self, x):
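+        # In the einsums, n is the batch index, t and s are time steps,
+        # c the input channels, h the head index, and d the key / value
+        # dimension, hence q, k and v of shape (N, H, T, d)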
+        q = torch.einsum('ntc,hdc->nhtd', x, self.wq)
+        k = torch.einsum('ntc,hdc->nhtd', x, self.wk)
+        v = torch.einsum('ntc,hdc->nhtd', x, self.wv)
+        r = math.sqrt(q.size(3))
+        a = torch.einsum('nhtd,nhsd->nhts', q, k).div(r)
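+        # In the causal case, position t can attend only to positions s <= t,
+        # so the entries with s > t are set to -inf before the softmax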
+        if self.causal:
+            mask = torch.tril(q.new_ones(a.size(2), a.size(3)))[None, None, :, :] == 0
+            a = a.masked_fill(mask, float('-inf'))
+        a = a.softmax(dim = 3)
+        a = F.dropout(a, self.attention_dropout, self.training)
+        y = torch.einsum('nhts,nhsd->nhtd', a, v)
+        return y.permute(0, 2, 1, 3).flatten(2) # nhtd -> nt(hd)
+
+##############################
+
+class MyGPT(nn.Module):
+    def __init__(self,
+                 vocabulary_size,
+                 dim_model, dim_keys, dim_hidden,
+                 nb_heads, nb_blocks, dropout = 0.):
+
+        super().__init__()
+
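+        # The values of the nb_heads heads are concatenated back into
+        # vectors of dimension dim_model, hence the constraint below and
+        # dim_v = dim_model // nb_heads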
+        assert dim_model % nb_heads == 0
+
+        self.embedding = nn.Sequential(
+            nn.Embedding(vocabulary_size, dim_model),
+            nn.Dropout(dropout),
+            PositionalEncoding(len_max = 1e5),
+        )
+
+        trunk_blocks = [ ]
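+        # Each block applies, in pre-normalization residual form, a causal
+        # multi-head attention layer followed by a two-layer MLP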
+
+        for _ in range(nb_blocks):
+            trunk_blocks += [
+                Residual(
+                    nn.LayerNorm(dim_model),
+                    QKVAttention(
+                        dim_in = dim_model,
+                        dim_qk = dim_keys, dim_v = dim_model // nb_heads,
+                        nb_heads = nb_heads,
+                        causal = True, attention_dropout = dropout
+                    ),
+                    nn.Linear(in_features = dim_model, out_features = dim_model),
+                ),
+                Residual(
+                    nn.LayerNorm(dim_model),
+                    nn.Linear(in_features = dim_model, out_features = dim_hidden),
+                    nn.ReLU(),
+                    nn.Linear(in_features = dim_hidden, out_features = dim_model),
+                    nn.Dropout(dropout),
+                ),
+            ]
+
+        self.trunk = nn.Sequential(*trunk_blocks)
+
+        self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)
+
+    def forward(self, x):
+        x = self.embedding(x)
+        x = self.trunk(x)
+        x = self.readout(x)
+        return x
+
+######################################################################
+
+class Task:
+    def batches(self, split = 'train'):
+        pass
+
+    def vocabulary_size(self):
+        pass
+
+    def produce_results(self, n_epoch, model, nb_tokens = 50):
+        pass
+
+######################################################################
+
+import picoclvr
+
+class TaskPicoCLVR(Task):
+
+    def __init__(self, batch_size, height = 6, width = 8, device = torch.device('cpu')):
+        self.batch_size = batch_size
+        self.device = device
+        nb = args.data_size if args.data_size > 0 else 250000
+
+        descr = picoclvr.generate(nb, height = height, width = width)
+        descr = [ s.strip().split(' ') for s in descr ]
+        l = max([ len(s) for s in descr ])
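+        # Pad all descriptions to the same length with '<unk>' tokens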
+        descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
+
+        tokens = set()
+        for s in descr:
+            for t in s: tokens.add(t)
+        # Sort the tokens to make the vocabulary deterministic, since the
+        # iteration order of a set of strings varies from run to run
+        tokens = sorted(tokens)
+        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
+        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
+
+        t = [ [ self.token2id[u] for u in s ] for s in descr ]
+        data_input = torch.tensor(t, device = self.device)
+
+        self.test_input = data_input[:nb // 5]
+        self.train_input = data_input[nb // 5:]
+
+    def batches(self, split = 'train'):
+        assert split in { 'train', 'test' }
+        if split == 'train':
+            for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = 'epoch'):
+                yield batch
+        else:
+            for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = 'epoch'):
+                yield batch
+
+    def vocabulary_size(self):
+        return len(self.token2id)
+
+    def produce_results(self, n_epoch, model, nb_tokens = 50):
+        img = [ ]
+        nb_per_primer = 8
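+        # Complete each primer nb_per_primer times: at every step, feed the
+        # primer followed by the tokens generated so far, and either sample
+        # the next token from the predicted distribution or take its argmax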
+        for primer in [
+                'red above green <sep> green top <sep> blue right of red <img>',
+                'there is red <sep> there is yellow <sep> there is blue <img>',
+                'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
+                'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
+        ]:
+
+            for k in range(nb_per_primer):
+                t_primer = primer.strip().split(' ')
+                t_generated = [ ]
+
+                for j in range(nb_tokens):
+                    t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
+                    input = torch.tensor(t, device = self.device)
+                    output = model(input)
+                    logits = output[0, -1]
+                    if args.synthesis_sampling:
+                        dist = torch.distributions.categorical.Categorical(logits = logits)
+                        t = dist.sample()
+                    else:
+                        t = logits.argmax()
+                    t_generated.append(self.id2token[t.item()])
+
+                descr = [ ' '.join(t_primer + t_generated) ]
+                img += [ picoclvr.descr2img(descr) ]
+
+        img = torch.cat(img, 0)
+        file_name = f'result_picoclvr_{n_epoch:04d}.png'
+        torchvision.utils.save_image(img / 255.,
+                                     file_name, nrow = nb_per_primer, pad_value = 0.8)
+        log_string(f'wrote {file_name}')
+
+######################################################################
+
+class TaskWiki103(Task):
+
+    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
+                 device = torch.device('cpu')):
+
+        self.batch_size = batch_size
+        self.len_min = len_min
+        self.len_max = len_max
+        self.min_freq = min_freq
+        self.device = device
+
+        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
+        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
+
+        # Mostly for debug
+        if args.data_size > 0:
+            train_iter = itertools.islice(train_iter, args.data_size)
+
+        def yield_tokens():
+            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
+                yield self.tokenizer(l)
+
+        self.vocab = torchtext.vocab.build_vocab_from_iterator(
+            yield_tokens(),
+            specials = [ '<unk>', '<non>' ],
+            min_freq = self.min_freq
+        )
+
+        self.vocab.set_default_index(self.vocab[ '<unk>' ])
+
+    def tensorize(self, s):
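+        # Pad all the sequences of the batch to the same length with '<non>'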
+        a = max(len(x) for x in s)
+        return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
+
+    def yield_batches(self, ds):
+        s = [ ]
+        for l in ds:
+            q = self.tokenizer(l)
+            if len(q) >= self.len_min and len(q) <= self.len_max:
+                s += [ q ]
+                if len(s) == self.batch_size:
+                    yield self.tensorize(s)
+                    s = [ ]
+
+        if len(s) > 0:
+            yield self.tensorize(s)
+
+    def batches(self, split = 'train'):
+        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
+
+        # Mostly for debug
+        if args.data_size > 0:
+            data_iter = itertools.islice(data_iter, args.data_size)
+
+        return self.yield_batches(tqdm.tqdm(data_iter, desc = 'epoch'))
+
+    def vocabulary_size(self):
+        return len(self.vocab)
+
+    def produce_results(self, n_epoch, model, nb_tokens = 50):
+        file_name = f'result_wiki103_{n_epoch:04d}.txt'
+
+        with open(file_name, 'w') as outfile:
+            for primer in [
+                    'the cat is hunting a',
+                    'paris is the capital',
+                    'cars are convenient',
+                    'the difference between men and women is',
+                    'the object was blue all over and green all over it was',
+                    'cherries are red and lemons are',
+                    'cherries are sweet and lemons are',
+                    'two plus three equals',
+                    'deep learning is',
+            ]:
+                t_primer = self.tokenizer(primer)
+                t_generated = [ ]
+
+                for j in range(nb_tokens):
+
+                    input = self.tensorize([ t_primer + t_generated ]).to(self.device)
+                    output = model(input)
+                    logits = output[0, -1]
+                    if args.synthesis_sampling:
+                        dist = torch.distributions.categorical.Categorical(logits = logits)
+                        t = dist.sample()
+                    else:
+                        t = logits.argmax()
+                    t_generated.append(self.vocab.lookup_token(t))
+                    if t_generated[-1] == '<non>': break
+
+                s = ' '.join(t_generated)
+
+                outfile.write(f'<{primer}> {s}\n')
+
+        log_string(f'wrote {file_name}')
+
+######################################################################
+
+class TaskMNIST(Task):
+
+    def __init__(self, batch_size, device = torch.device('cpu')):
+        self.device = device
+        self.batch_size = batch_size
+
+    def batches(self, split = 'train'):
+        assert split in { 'train', 'test' }
+        data_set = torchvision.datasets.MNIST(
+            root = './data', train = (split == 'train'),
+            download = True
+        )
+        data_input = data_set.data.view(-1, 28 * 28).long()
+        if args.data_size >= 0:
+            data_input = data_input[:args.data_size]
+        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = 'epoch'):
+            yield batch
+
+    def vocabulary_size(self):
+        return 256
+
+    def produce_results(self, n_epoch, model, nb_samples = 64):
+        results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
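+        # Generate the images autoregressively, one pixel at a time. Since
+        # split returns views into results, writing into input fills results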
+        for input in results.split(self.batch_size):
+            for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
+                output = model(input)
+                logits = output[:, s]
+                if args.synthesis_sampling:
+                    dist = torch.distributions.categorical.Categorical(logits = logits)
+                    t = dist.sample()
+                else:
+                    t = logits.argmax(1)
+                input[:, s + 1] = t
+
+        image_name = f'result_mnist_{n_epoch:04d}.png'
+        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
+                                     image_name, nrow = 16, pad_value = 0.8)
+        log_string(f'wrote {image_name}')
+
+######################################################################
+
+def check_causality(model):
+    # Sanity check of the causal masking: the output at position t must
+    # depend only on the inputs at positions s <= t. We feed a random
+    # sequence through the attention trunk (the embedding is skipped since
+    # it expects token indices) and accumulate in a[k, s] the squared
+    # gradient of the outputs at k w.r.t. the input at s; the entries
+    # with s > k must be zero.
+    input = torch.rand(1, 5, args.dim_model).requires_grad_()
+    output = model.trunk(input)
+    a = torch.zeros(output.size(1), input.size(1))
+    for k in range(output.size(1)):
+        for d in range(output.size(2)):
+            g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
+            a[k] += g.squeeze(0).pow(2).sum(1)
+    print(a)
+
+######################################################################
+
+log_string(f'device {device}')
+
+if args.data == 'wiki103':
+    task = TaskWiki103(batch_size = args.batch_size, device = device)
+elif args.data == 'mnist':
+    task = TaskMNIST(batch_size = args.batch_size, device = device)
+elif args.data == 'picoclvr':
+    task = TaskPicoCLVR(batch_size = args.batch_size, device = device)
+else:
+    raise ValueError(f'Unknown dataset {args.data}.')
+
+vocabulary_size = task.vocabulary_size()
+
+log_string(f'vocabulary_size {vocabulary_size}')
+
+##############################
+
+model = MyGPT(
+    vocabulary_size = vocabulary_size,
+    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
+    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
+)
+
+nb_parameters = sum(p.numel() for p in model.parameters())
+log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
+
+model.to(device)
+
+######################################################################
+
+if args.optim == 'sgd':
+    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
+elif args.optim == 'adam':
+    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
+elif args.optim == 'adamw':
+    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
+else:
+    raise ValueError(f'Unknown optimizer {args.optim}.')
+
+for k in range(args.nb_epochs):
+
+    model.train()
+
+    nb_train_samples, acc_train_loss = 0, 0.0
+
+    for input in task.batches(split = 'train'):
+        input = input.to(device)
+        output = model(input)
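+        # The output at position t is trained to predict the token at t + 1,
+        # hence the one-token shift between output and target. cross_entropy
+        # expects the class dimension second, hence the transpose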
+        loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
+        acc_train_loss += loss.item() * input.size(0)
+        nb_train_samples += input.size(0)
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+    with torch.no_grad():
+
+        model.eval()
+
+        nb_test_samples, acc_test_loss = 0, 0.0
+
+        for input in task.batches(split = 'test'):
+            input = input.to(device)
+            output = model(input)
+            loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
+            acc_test_loss += loss.item() * input.size(0)
+            nb_test_samples += input.size(0)
+
+        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
+        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
+
+        log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
+
+        task.produce_results(k, model)
+
+######################################################################
diff --git a/picoclvr.py b/picoclvr.py
new file mode 100755 (executable)
index 0000000..a194a1c
--- /dev/null
+++ b/picoclvr.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python
+
+import torch, torchvision
+
+colors = [
+    [ 255, 255, 255 ],
+    [ 255,   0,   0 ],
+    [   0, 255,   0 ],
+    [   0,   0, 255 ],
+    [ 255, 255,   0 ],
+    [   0,   0,   0 ],
+]
+
+color_names = [
+    'white',
+    'red',
+    'green',
+    'blue',
+    'yellow',
+    'black',
+]
+
+color_tokens = dict( [ (n, c) for n, c in zip(color_names, colors) ] )
+
+def generate(nb, height = 6, width = 8, max_nb_statements = 10):
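+    # Generates nb scene descriptions. Each one is a sequence of statements
+    # separated by '<sep>' tokens, followed by '<img>' and the height * width
+    # color names of the flattened grid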
+
+    descr = [ ]
+
+    for n in range(nb):
+        # Do not overwrite nb, the total number of descriptions to generate
+        nb_shapes = torch.randint(5, (1,)) + 1
+        shape_position = torch.randperm(height * width)[:nb_shapes]
+        shape_c = torch.randperm(5)[:nb_shapes] + 1
+        shape_i = shape_position.div(width, rounding_mode = 'floor')
+        shape_j = shape_position % width
+
+        img = [ 0 ] * height * width
+        for k in range(nb_shapes): img[shape_position[k]] = shape_c[k]
+
+        s = [ ]
+
+        for r, c in [ (k, color_names[shape_c[k]]) for k in range(nb_shapes) ]:
+            s += [ f'there is {c}' ]
+
+            if shape_i[r] >= height - height/4: s += [ f'{c} bottom' ]
+            if shape_i[r] < height/4: s += [ f'{c} top' ]
+            if shape_j[r] >= width - width/4: s += [ f'{c} right' ]
+            if shape_j[r] < width/4: s += [ f'{c} left' ]
+
+            for t, d in [ (k, color_names[shape_c[k]]) for k in range(nb_shapes) ]:
+                if shape_i[r] > shape_i[t]: s += [ f'{c} below {d}' ]
+                if shape_i[r] < shape_i[t]: s += [ f'{c} above {d}' ]
+                if shape_j[r] > shape_j[t]: s += [ f'{c} right of {d}' ]
+                if shape_j[r] < shape_j[t]: s += [ f'{c} left of {d}' ]
+
+        nb_statements = torch.randint(max_nb_statements, (1,)) + 1
+        s = ' <sep> '.join([ s[k] for k in torch.randperm(len(s))[:nb_statements] ] )
+        s += ' <img> ' + ' '.join([ f'{color_names[n]}' for n in img ])
+        descr += [ s ]
+
+    return descr
+
+######################################################################
+
+def descr2img(descr, height = 6, width = 8):
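+    # Converts descriptions back into image tensors by keeping the
+    # height * width color tokens that follow '<img>'; any other token is
+    # rendered in gray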
+
+    def token2color(t):
+        try:
+            return color_tokens[t]
+        except KeyError:
+            return [ 128, 128, 128 ]
+
+    def img_descr(x):
+        u = x.split('<img>', 1)
+        return u[1] if len(u) > 1 else ''
+
+    d = [ img_descr(x) for x in descr ]
+    d = [ u.strip().split(' ')[:height * width] for u in d ]
+    d = [ u + [ '<unk>' ] * (height * width - len(u)) for u in d ]
+    d = [ [ token2color(t) for t in u ] for u in d ]
+    img = torch.tensor(d).permute(0, 2, 1)
+    img = img.reshape(img.size(0), 3, height, width)
+
+    return img
+
+######################################################################
+
+if __name__ == '__main__':
+    descr = generate(5)
+    img = descr2img(descr)
+    print(descr, img.size())
+    torchvision.utils.save_image(img / 255.,
+                                 'example.png', nrow = 16, pad_value = 0.8)
+
+    import time
+
+    start_time = time.perf_counter()
+    descr = generate(10000)
+    end_time = time.perf_counter()
+    print(f'{len(descr) / (end_time - start_time):.02f} samples per second')
+
+######################################################################