Added default configurations and reformated with black.
[mygpt.git] / mygpt.py
index 13fbe8e..a6b257c 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -5,95 +5,17 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import math, sys, argparse, time, tqdm, itertools
+import math
+
+import torch
 
-import torch, torchtext, torchvision
 from torch import nn
 from torch.nn import functional as F
 
-######################################################################
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-######################################################################
-
-parser = argparse.ArgumentParser(description = 'My own GPT.')
-
-parser.add_argument('--log_filename',
-                    type = str, default = 'train.log')
-
-parser.add_argument('--download',
-                    type = bool, default = False)
-
-parser.add_argument('--seed',
-                    type = int, default = 0)
-
-parser.add_argument('--nb_epochs',
-                    type = int, default = 100)
-
-parser.add_argument('--batch_size',
-                    type = int, default = 25)
-
-parser.add_argument('--data',
-                    type = str, default = 'wiki103')
-
-parser.add_argument('--data_size',
-                    type = int, default = -1)
-
-parser.add_argument('--optim',
-                    type = str, default = 'adam')
-
-parser.add_argument('--learning_rate',
-                    type = float, default = 1e-4)
-
-parser.add_argument('--dim_model',
-                    type = int, default = 512)
-
-parser.add_argument('--dim_keys',
-                    type = int, default = 64)
-
-parser.add_argument('--dim_hidden',
-                    type = int, default = 2048)
-
-parser.add_argument('--nb_heads',
-                    type = int, default = 8)
-
-parser.add_argument('--nb_blocks',
-                    type = int, default = 12)
-
-parser.add_argument('--dropout',
-                    type = float, default = 0.1)
-
-parser.add_argument('--synthesis_sampling',
-                    type = bool, default = True)
-
-######################################################################
-
-args = parser.parse_args()
-
-log_file = open(args.log_filename, 'w')
-
-if args.seed >= 0:
-    torch.manual_seed(args.seed)
-
-######################################################################
-
-def log_string(s):
-    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
-
-    if log_file is not None:
-        log_file.write(t + s + '\n')
-        log_file.flush()
-
-    print(t + s)
-    sys.stdout.flush()
-
-for n in vars(args):
-    log_string(f'args.{n} {getattr(args, n)}')
-
 ##############################
 
-class Residual(nn.Module):
+
+class WithResidual(nn.Module):
     def __init__(self, *f):
         super().__init__()
         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
@@ -101,58 +23,85 @@ class Residual(nn.Module):
     def forward(self, x):
         return x + self.f(x)
 
+
 ##############################
 
-class PositionalEncoding(nn.Module):
+
+class AddPositionalEncoding(nn.Module):
     def __init__(self, len_max):
         super().__init__()
         self.len_max = len_max
 
-    # From Vaswani et al 2018
-    # PE_{t,2i}   = sin(t/(L^{2i/D}))
-    # PE_{t,2i+1} = cos(t/(L^{2i/D}))
+    # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
     def forward(self, x):
-        t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
-        j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
-        k = j%2
-        return x + torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)[None, :, :]
+        t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
+        j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
+        k = j % 2
+        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
+        return x + pe
+
 
 ##############################
 
+
 class QKVAttention(nn.Module):
-    def __init__(self, dim_in, dim_qk, dim_v, nb_heads = 1, causal = False, attention_dropout = 0.0):
+    def __init__(
+        self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+    ):
         super().__init__()
 
         def randw(*d):
-            return nn.Parameter(torch.empty(*d).normal_(0, 1 / math.sqrt(d[-1])))
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
 
-        self.wq = randw(nb_heads, dim_qk, dim_in)
-        self.wk = randw(nb_heads, dim_qk, dim_in)
-        self.wv = randw(nb_heads, dim_v, dim_in)
         self.causal = causal
         self.attention_dropout = attention_dropout
 
-    def forward(self, x):
-        q = torch.einsum('ntc,hdc->nhtd', x, self.wq)
-        k = torch.einsum('ntc,hdc->nhtd', x, self.wk)
-        v = torch.einsum('ntc,hdc->nhtd', x, self.wv)
-        r = math.sqrt(q.size(3))
-        a = torch.einsum('nhtd,nhsd->nhts', q, k).div(r)
+        self.w_q = randw(nb_heads, dim_qk, dim_in)
+        self.w_k = randw(nb_heads, dim_qk, dim_in)
+        self.w_v = randw(nb_heads, dim_v, dim_in)
+        self.w_o = randw(dim_v * nb_heads, dim_in)
+
+    def forward(self, x_q, x_kv=None):
+        if x_kv is None:
+            x_kv = x_q
+
+        q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
+        k = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_k)
+        v = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_v)
+
+        a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
+
         if self.causal:
-            mask = torch.tril(q.new_ones(a.size(2), a.size(3)))[None, None, :, :] == 0
-            a = a.masked_fill(mask, float('-inf'))
-        a = a.softmax(dim = 3)
+            forbidden_attention = (
+                torch.arange(a.size(2), device=q.device)[None, None, :, None]
+                < torch.arange(a.size(3), device=q.device)[None, None, None, :]
+            )
+            a = a.masked_fill(forbidden_attention, float("-inf"))
+
+        a = a.softmax(dim=3)
         a = F.dropout(a, self.attention_dropout, self.training)
-        y = torch.einsum('nhts,nhsd->nhtd', a, v)
-        return y.permute(0, 2, 1, 3).flatten(2) # nhtd -> nt(hd)
+        y = torch.einsum("nhts,nhsd->nthd", a, v).flatten(2)
+
+        y = y @ self.w_o
+
+        return y
+
 
 ##############################
 
+
 class MyGPT(nn.Module):
-    def __init__(self,
-                 vocabulary_size,
-                 dim_model, dim_keys, dim_hidden,
-                 nb_heads, nb_blocks, dropout = 0.):
+    def __init__(
+        self,
+        vocabulary_size,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        dropout=0.0,
+        len_max=1e5,
+    ):
 
         super().__init__()
 
@@ -161,358 +110,71 @@ class MyGPT(nn.Module):
         self.embedding = nn.Sequential(
             nn.Embedding(vocabulary_size, dim_model),
             nn.Dropout(dropout),
-            PositionalEncoding(len_max = 1e5),
+            AddPositionalEncoding(len_max),
         )
 
-        trunk_blocks = [ ]
+        trunk_blocks = []
 
         for _ in range(nb_blocks):
             trunk_blocks += [
-                Residual(
-                    nn.LayerNorm(dim_model),
+                WithResidual(
+                    nn.LayerNorm((dim_model,)),
                     QKVAttention(
-                        dim_in = dim_model,
-                        dim_qk = dim_keys, dim_v = dim_model // nb_heads,
-                        nb_heads = nb_heads,
-                        causal = True, attention_dropout = dropout
+                        dim_in=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        causal=True,
+                        attention_dropout=dropout,
                     ),
-                    nn.Linear(in_features = dim_model, out_features = dim_model),
                 ),
-                Residual(
-                    nn.LayerNorm(dim_model),
-                    nn.Linear(in_features = dim_model, out_features = dim_hidden),
+                WithResidual(
+                    nn.LayerNorm((dim_model,)),
+                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
                     nn.ReLU(),
-                    nn.Linear(in_features = dim_hidden, out_features = dim_model),
+                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
                     nn.Dropout(dropout),
                 ),
             ]
 
         self.trunk = nn.Sequential(*trunk_blocks)
 
-        self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)
+        self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
 
     def forward(self, x):
+        x = F.pad(x, (1, -1))
         x = self.embedding(x)
         x = self.trunk(x)
         x = self.readout(x)
         return x
 
-######################################################################
-
-class Task:
-    def batches(self, split = 'train'):
-        pass
-
-    def vocabulary_size(self):
-        pass
-
-    def produce_results(self, n_epoch, model, nb_tokens = 50):
-        pass
 
 ######################################################################
 
-import picoclvr
-
-class TaskPicoCLVR(Task):
-
-    def __init__(self, batch_size, height = 6, width = 8, device = torch.device('cpu')):
-        self.batch_size = batch_size
-        self.device = device
-        nb = args.data_size if args.data_size > 0 else 250000
-
-        descr = picoclvr.generate(nb, height = height, width = width)
-        descr = [ s.strip().split(' ') for s in descr ]
-        l = max([ len(s) for s in descr ])
-        descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
-
-        tokens = set()
-        for s in descr:
-            for t in s: tokens.add(t)
-        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
-        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
-
-        t = [ [ self.token2id[u] for u in s ] for s in descr ]
-        data_input = torch.tensor(t, device = self.device)
-
-        self.test_input = data_input[:nb // 5]
-        self.train_input = data_input[nb // 5:]
-
-    def batches(self, split = 'train'):
-        assert split in { 'train', 'test' }
-        if split == 'train':
-            for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = 'epoch'):
-                yield batch
-        else:
-            for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = 'epoch'):
-                yield batch
-
-    def vocabulary_size(self):
-        return len(self.token2id)
-
-    def produce_results(self, n_epoch, model, nb_tokens = 50):
-        img = [ ]
-        nb_per_primer = 8
-
-        for primer in [
-                'red above green <sep> green top <sep> blue right of red <img>',
-                'there is red <sep> there is yellow <sep> there is blue <img>',
-                'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
-                'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
-        ]:
-
-            for k in range(nb_per_primer):
-                t_primer = primer.strip().split(' ')
-                t_generated = [ ]
-
-                for j in range(nb_tokens):
-                    t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
-                    input = torch.tensor(t, device = self.device)
-                    output = model(input)
-                    logits = output[0, -1]
-                    if args.synthesis_sampling:
-                        dist = torch.distributions.categorical.Categorical(logits = logits)
-                        t = dist.sample()
-                    else:
-                        t = logits.argmax()
-                    t_generated.append(self.id2token[t.item()])
-
-                descr = [ ' '.join(t_primer + t_generated) ]
-                img += [ picoclvr.descr2img(descr) ]
-
-        img = torch.cat(img, 0)
-        file_name = f'result_picoclvr_{n_epoch:04d}.png'
-        torchvision.utils.save_image(img / 255.,
-                                     file_name, nrow = nb_per_primer, pad_value = 0.8)
-        log_string(f'wrote {file_name}')
-
-######################################################################
-
-class TaskWiki103(Task):
-
-    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
-                 device = torch.device('cpu')):
-
-        self.batch_size = batch_size
-        self.len_min = len_min
-        self.len_max = len_max
-        self.min_freq = min_freq
-        self.device = device
-
-        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
-        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
-
-        # Mostly for debug
-        if args.data_size > 0:
-            train_iter = itertools.islice(train_iter, args.data_size)
-
-        def yield_tokens():
-            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
-                yield self.tokenizer(l)
-
-        self.vocab = torchtext.vocab.build_vocab_from_iterator(
-            yield_tokens(),
-            specials = [ '<unk>', '<non>' ],
-            min_freq = self.min_freq
-        )
-
-        self.vocab.set_default_index(self.vocab[ '<unk>' ])
-
-    def tensorize(self, s):
-        a = max(len(x) for x in s)
-        return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
-
-    def yield_batches(self, ds):
-        s = [ ]
-        for l in ds:
-            q = self.tokenizer(l)
-            if len(q) >= self.len_min and len(q) <= self.len_max:
-                s += [ q ]
-                if len(s) == self.batch_size:
-                    yield self.tensorize(s)
-                    s = [ ]
-
-        if len(s) > 0:
-            yield self.tensorize(s)
-
-    def batches(self, split = 'train'):
-        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
-
-        # Mostly for debug
-        if args.data_size > 0:
-            data_iter = itertools.islice(data_iter, args.data_size)
-
-        return self.yield_batches(tqdm.tqdm(data_iter, desc = 'epoch'))
-
-    def vocabulary_size(self):
-        return len(self.vocab)
-
-    def produce_results(self, n_epoch, model, nb_tokens = 50):
-        file_name = f'result_wiki103_{n_epoch:04d}.txt'
-
-        with open(file_name, 'w') as outfile:
-             for primer in [
-                     'the cat is hunting a',
-                     'paris is the capital',
-                     'cars are convenient',
-                     'the difference between men and women is',
-                     'the object was blue all over and green all over it was',
-                     'cherries are red and lemons are',
-                     'cherries are sweet and lemons are',
-                     'two plus three equals',
-                     'deep learning is',
-             ]:
-                 t_primer = self.tokenizer(primer)
-                 t_generated = [ ]
-
-                 for j in range(nb_tokens):
-
-                     input = self.tensorize([ t_primer + t_generated ]).to(self.device)
-                     output = model(input)
-                     logits = output[0, -1]
-                     if args.synthesis_sampling:
-                         dist = torch.distributions.categorical.Categorical(logits = logits)
-                         t = dist.sample()
-                     else:
-                         t = logits.argmax()
-                     t_generated.append(self.vocab.lookup_token(t))
-                     if t_generated[-1] == '<non>': break
-
-                 s = ' '.join(t_generated)
-
-                 outfile.write(f'<{primer}> {s}\n')
-
-        log_string(f'wrote {file_name}')
-
-######################################################################
-
-class TaskMNIST(Task):
-
-    def __init__(self, batch_size, device = torch.device('cpu')):
-        self.device = device
-        self.batch_size = batch_size
-
-    def batches(self, split = 'train'):
-        assert split in { 'train', 'test' }
-        data_set = torchvision.datasets.MNIST(
-            root = './data', train = (split == 'train'),
-            download = True
-        )
-        data_input = data_set.data.view(-1, 28 * 28).long()
-        if args.data_size >= 0:
-            data_input = data_input[:args.data_size]
-        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = 'epoch'):
-            yield batch
-
-    def vocabulary_size(self):
-        return 256
-
-    def produce_results(self, n_epoch, model, nb_samples = 64):
-        results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
-        for input in results.split(self.batch_size):
-            for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
-                output = model(input)
-                logits = output[:, s]
-                if args.synthesis_sampling:
-                    dist = torch.distributions.categorical.Categorical(logits = logits)
-                    t = dist.sample()
-                else:
-                    t = logits.argmax(1)
-                input[:, s + 1] = t
-
-        image_name = f'result_mnist_{n_epoch:04d}.png'
-        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
-                                     image_name, nrow = 16, pad_value = 0.8)
-        log_string(f'wrote {image_name}')
-
-######################################################################
-
-def check_causality(model):
-    #m = model[1:]
-    input = torch.rand(1, 5, dim_model).requires_grad_()
-    output = m(input)
-    a = torch.zeros(output.size(1), input.size(1))
-    for k in range(output.size(1)):
-        for d in range(output.size(2)):
-            g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
-            a[k] += g.squeeze(0).pow(2).sum(1)
-    print(a)
-
-######################################################################
-
-log_string(f'device {device}')
-
-if args.data == 'wiki103':
-    task = TaskWiki103(batch_size = args.batch_size, device = device)
-elif args.data == 'mnist':
-    task = TaskMNIST(batch_size = args.batch_size, device = device)
-elif args.data == 'picoclvr':
-    task = TaskPicoCLVR(batch_size = args.batch_size, device = device)
-else:
-    raise ValueError(f'Unknown dataset {args.data}.')
-
-vocabulary_size = task.vocabulary_size()
-
-log_string(f'vocabulary_size {vocabulary_size}')
-
-##############################
-
-model = MyGPT(
-    vocabulary_size = vocabulary_size,
-    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
-    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
-)
-
-nb_parameters = sum(p.numel() for p in model.parameters())
-log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
-
-model.to(device)
-
-######################################################################
-
-if args.optim == 'sgd':
-    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adam':
-    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adamw':
-    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
-else:
-    raise ValueError(f'Unknown optimizer {args.optim}.')
-
-for k in range(args.nb_epochs):
-
-    model.train()
-
-    nb_train_samples, acc_train_loss = 0, 0.0
-
-    for input in task.batches(split = 'train'):
-        input = input.to(device)
-        output = model(input)
-        loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
-        acc_train_loss += loss.item() * input.size(0)
-        nb_train_samples += input.size(0)
-
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-
-    with torch.autograd.no_grad():
-
-        model.eval()
-
-        nb_test_samples, acc_test_loss = 0, 0.0
-
-        for input in task.batches(split = 'test'):
-            input = input.to(device)
-            output = model(input)
-            loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
-            acc_test_loss += loss.item() * input.size(0)
-            nb_test_samples += input.size(0)
+if __name__ == "__main__":
+    print("Basic check.")
 
-        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
-        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
+    vocabulary_size = 10
+    x = torch.randint(vocabulary_size, (25, 100))
 
-        log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
+    model = MyGPT(
+        vocabulary_size=vocabulary_size,
+        dim_model=18,
+        dim_keys=50,
+        dim_hidden=100,
+        nb_heads=2,
+        nb_blocks=3,
+        dropout=0.1,
+    )
 
-        task.produce_results(k, model)
+    y = model(x)
 
 ######################################################################