######################################################################
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
######################################################################
+parser = argparse.ArgumentParser(description="My own GPT.")
-parser = argparse.ArgumentParser(description = 'My own GPT.')
+parser.add_argument("--log_filename", type=str, default="train.log")
-parser.add_argument('--log_filename',
- type = str, default = 'train.log')
+parser.add_argument("--seed", type=int, default=0)
-parser.add_argument('--seed',
- type = int, default = 0)
+parser.add_argument("--nb_epochs", type=int, default=None)
-parser.add_argument('--nb_epochs',
- type = int, default = -1)
+parser.add_argument("--batch_size", type=int, default=25)
-parser.add_argument('--batch_size',
- type = int, default = 25)
+parser.add_argument("--data", type=str, default="wiki103")
-parser.add_argument('--data',
- type = str, default = 'wiki103')
+parser.add_argument("--data_size", type=int, default=None)
-parser.add_argument('--data_size',
- type = int, default = -1)
+parser.add_argument("--optim", type=str, default="adam")
-parser.add_argument('--optim',
- type = str, default = 'adam')
+parser.add_argument("--learning_rate", type=float, default=1e-3)
-parser.add_argument('--learning_rate',
- type = float, default = 1e-4)
+parser.add_argument("--learning_rate_end", type=float, default=1e-6)
-parser.add_argument('--dim_model',
- type = int, default = 512)
+parser.add_argument("--dim_model", type=int, default=None)
-parser.add_argument('--dim_keys',
- type = int, default = 64)
+parser.add_argument("--dim_keys", type=int, default=None)
-parser.add_argument('--dim_hidden',
- type = int, default = 2048)
+parser.add_argument("--dim_hidden", type=int, default=None)
-parser.add_argument('--nb_heads',
- type = int, default = 8)
+parser.add_argument("--nb_heads", type=int, default=None)
-parser.add_argument('--nb_blocks',
- type = int, default = 12)
+parser.add_argument("--nb_blocks", type=int, default=None)
-parser.add_argument('--dropout',
- type = float, default = 0.1)
+parser.add_argument("--dropout", type=float, default=0.1)
-parser.add_argument('--synthesis_sampling',
- action='store_true', default = True)
+parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
-parser.add_argument('--no_checkpoint',
- action='store_true', default = False)
+parser.add_argument("--no_checkpoint", action="store_true", default=False)
-parser.add_argument('--checkpoint_name',
- type = str, default = 'checkpoint.pth')
+parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
##############################
# picoclvr options
-parser.add_argument('--picoclvr_nb_colors',
- type = int, default = 5)
+parser.add_argument("--picoclvr_nb_colors", type=int, default=5)
-parser.add_argument('--picoclvr_height',
- type = int, default = 12)
+parser.add_argument("--picoclvr_height", type=int, default=12)
-parser.add_argument('--picoclvr_width',
- type = int, default = 16)
+parser.add_argument("--picoclvr_width", type=int, default=16)
######################################################################
args = parser.parse_args()
-log_file = open(args.log_filename, 'w')
+log_file = open(args.log_filename, "w")
if args.seed >= 0:
torch.manual_seed(args.seed)
######################################################################
+
def log_string(s):
- t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
+ t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
if log_file is not None:
- log_file.write(t + s + '\n')
+ log_file.write(t + s + "\n")
log_file.flush()
print(t + s)
sys.stdout.flush()
+
for n in vars(args):
- log_string(f'args.{n} {getattr(args, n)}')
+ log_string(f"args.{n} {getattr(args, n)}")
+
+######################################################################
+
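+# Per-dataset defaults: any hyper-parameter left at None on the command
+# line is filled in from this table according to --data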
+default_args = {
+ "mnist": {
+ "nb_epochs": 10,
+ "dim_model": 64,
+ "dim_keys": 64,
+ "dim_hidden": 128,
+ "nb_heads": 4,
+ "nb_blocks": 6,
+ },
+ "mnist-debug": {
+ "nb_epochs": 2,
+ "data_size": 10000,
+ "dim_model": 8,
+ "dim_keys": 8,
+ "dim_hidden": 8,
+ "nb_heads": 2,
+ "nb_blocks": 4,
+ },
+ "wiki103": {
+ "nb_epochs": 25,
+ "dim_model": 512,
+ "dim_keys": 64,
+ "dim_hidden": 2048,
+ "nb_heads": 8,
+ "nb_blocks": 12,
+ },
+ "picoclvr": {
+ "nb_epochs": 25,
+ "dim_model": 512,
+ "dim_keys": 64,
+ "dim_hidden": 2048,
+ "nb_heads": 8,
+ "nb_blocks": 12,
+ },
+}
+
+if args.data in default_args:
+ for k, v in default_args[args.data].items():
+ if getattr(args, k) is None:
+ setattr(args, k, v)
######################################################################
+
def autoregression(
- model, batch_size,
- nb_samples, nb_tokens_to_generate, primer = None,
- device = torch.device('cpu')
+ model,
+ batch_size,
+ nb_samples,
+ nb_tokens_to_generate,
+ primer=None,
+ device=torch.device("cpu"),
):
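+    # Generate nb_samples sequences token by token: mygpt is assumed to
+    # shift its input so that output[:, s] is the prediction for token s,
+    # which is sampled (or arg-maxed) and written back into the input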
results = torch.zeros(
- nb_samples, nb_tokens_to_generate,
- dtype = torch.int64, device = device
+ nb_samples, nb_tokens_to_generate, dtype=torch.int64, device=device
)
    if primer is None:
        first = 0
    else:
        first = primer.size(1)
        results = torch.cat((primer, results), 1)
for input in results.split(batch_size):
- for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
+ for s in range(first, input.size(1)):
output = model(input)
logits = output[:, s]
- if args.synthesis_sampling:
- dist = torch.distributions.categorical.Categorical(logits = logits)
- t_next = dist.sample()
- else:
+ if args.deterministic_synthesis:
t_next = logits.argmax(1)
+ else:
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ t_next = dist.sample()
input[:, s] = t_next
return results
+
######################################################################
+
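+# A Task wraps a dataset: it provides tokenized train/test batches, the
+# vocabulary size, and a per-epoch routine that evaluates the model by
+# sampling from it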
class Task:
- def batches(self, split = 'train'):
+ def batches(self, split="train"):
pass
def vocabulary_size(self):
pass
- def produce_results(self, n_epoch, model, nb_tokens = 50):
+ def produce_results(self, n_epoch, model):
pass
+
######################################################################
import picoclvr
-class TaskPicoCLVR(Task):
-
- def descr2tensor(self, descr):
- t = [ [ self.token2id[u] for u in s ] for s in descr ]
- return torch.tensor(t, device = self.device)
- def __init__(self, batch_size,
- height, width, nb_colors = 5,
- device = torch.device('cpu')):
+class TaskPicoCLVR(Task):
+ # Make a tensor from a list of strings
+ def tensorize(self, descr):
+ token_descr = [s.strip().split(" ") for s in descr]
+ l = max([len(s) for s in token_descr])
+ padded_token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
+ id_descr = [[self.token2id[u] for u in s] for s in padded_token_descr]
+ return torch.tensor(id_descr, device=self.device)
+
+ def trim(self, x, token="<nul>"):
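+        # Remove the columns that contain only <nul> at both ends of the
+        # batch, i.e. keep the smallest window of columns that still
+        # contains every non-<nul> token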
+ n = self.token2id[token]
+ i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+ a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+ return x[:, a:b]
+
+ def __init__(
+ self, batch_size, height, width, nb_colors=5, device=torch.device("cpu")
+ ):
def generate_descr(nb):
- descr = picoclvr.generate(
- nb,
- height = self.height, width = self.width,
- nb_colors = nb_colors
+ return picoclvr.generate(
+ nb, height=self.height, width=self.width, nb_colors=nb_colors
)
- descr = [ s.strip().split(' ') for s in descr ]
- l = max([ len(s) for s in descr ])
- #descr = [ [ '<unk>' ] * (l - len(s)) + s for s in descr ]
- descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
-
- return descr
-
self.height = height
self.width = width
self.batch_size = batch_size
self.device = device
- nb = args.data_size if args.data_size > 0 else 250000
+ nb = args.data_size if args.data_size is not None else 250000
+ log_string(f"generating {nb} samples (can take some time)")
self.train_descr = generate_descr((nb * 4) // 5)
self.test_descr = generate_descr((nb * 1) // 5)
# Build the tokenizer
- tokens = set()
- for d in [ self.train_descr, self.test_descr ]:
+ tokens = {"<nul>"}
+ for d in [self.train_descr, self.test_descr]:
for s in d:
- for t in s: tokens.add(t)
- self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
- self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
+ for t in s.strip().split(" "):
+ tokens.add(t)
+ self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
+ self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
# Tokenize the train and test sets
- self.train_input = descr2tensor(self.train_descr)
- self.test_input = descr2tensor(self.test_descr)
-
- def batches(self, split = 'train'):
- assert split in { 'train', 'test' }
- if split == 'train':
- for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
- yield batch
- else:
- for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
- yield batch
+ self.train_input = self.tensorize(self.train_descr)
+ self.test_input = self.tensorize(self.test_descr)
+
+ def batches(self, split="train"):
+ assert split in {"train", "test"}
+ input = self.train_input if split == "train" else self.test_input
+ for batch in tqdm.tqdm(input.split(self.batch_size), desc=f"epoch-{split}"):
+ yield self.trim(batch)
def vocabulary_size(self):
return len(self.token2id)
- def generate(self, descr_primer, model, nb_tokens):
- results = autoregression(
- model, self.batch_size,
- 1, nb_tokens, primer = descr2tensor(descr_primer),
- device = self.device
- )
- return ' '.join([ self.id2token[t.item()] for t in results.flatten() ])
-
- def produce_results(self, n_epoch, model, nb_tokens = None):
- if nb_tokens is None:
- nb_tokens = self.height * self.width + 3
- result_descr = [ ]
- nb_per_primer = 8
-
- for descr_primer in [
- 'red above green <sep> green top <sep> blue right of red <img>',
- 'there is red <sep> there is yellow <sep> there is blue <img>',
- 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
- 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
- ]:
-
- for k in range(nb_per_primer):
- result_descr.append(self.generate(descr_primer, model, nb_tokens))
-
- img = [ picoclvr.descr2img(d, height = self.height, width = self.width)
- for d in result_descr ]
- img = torch.cat(img, 0)
- image_name = f'result_picoclvr_{n_epoch:04d}.png'
- torchvision.utils.save_image(
- img / 255.,
- image_name, nrow = nb_per_primer, pad_value = 0.8
+ def test_model(
+ self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False
+ ):
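+        # Generate nb_per_primer descriptions for every primer, then use
+        # picoclvr to count how many of the requested properties each
+        # generated image actually misses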
+ nb_tokens_to_generate = self.height * self.width + 3
+ result_descr = []
+
+ for primer_descr in primers_descr:
+
+ results = autoregression(
+ model,
+ self.batch_size,
+ nb_samples=nb_per_primer,
+ nb_tokens_to_generate=nb_tokens_to_generate,
+ primer=self.tensorize([primer_descr]).expand(nb_per_primer, -1),
+ device=self.device,
+ )
+
+ l = [" ".join([self.id2token[t.item()] for t in r]) for r in results]
+ result_descr += l
+
+ np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
+
+ nb_requested_properties, _, nb_missing_properties = zip(*np)
+
+ log_string(
+ f"nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}"
)
- log_string(f'wrote {image_name}')
- np = picoclvr.nb_properties(
- result_descr,
- height = self.height, width = self.width
+ np = torch.tensor(np)
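+        # count[i, j] = number of generated samples with i requested
+        # properties, j of which are missing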
+ count = torch.empty(np[:, 0].max() + 1, np[:, 2].max() + 1, dtype=torch.int64)
+ for i in range(count.size(0)):
+ for j in range(count.size(1)):
+ count[i, j] = ((np[:, 0] == i).long() * (np[:, 2] == j).long()).sum()
+
+ if generate_images:
+ img = [
+ picoclvr.descr2img(d, height=self.height, width=self.width)
+ for d in result_descr
+ ]
+
+ img = torch.cat(img, 0)
+ image_name = f"result_picoclvr_{n_epoch:04d}.png"
+ torchvision.utils.save_image(
+ img / 255.0, image_name, nrow=nb_per_primer, pad_value=0.8
+ )
+ log_string(f"wrote {image_name}")
+
+ return count
+
+ def produce_results(self, n_epoch, model):
+ primers_descr = [
+ "red above green <sep> green top <sep> blue right of red <img>",
+ "there is red <sep> there is yellow <sep> there is blue <img>",
+ "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>",
+ "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>",
+ ]
+
+ self.test_model(
+ n_epoch, model, primers_descr, nb_per_primer=8, generate_images=True
)
- nb_requested_properties, _, nb_missing_properties = zip(*np)
+ # FAR TOO SLOW!!!
+
+ # test_primers_descr=[ s.split('<img>')[0] for s in self.test_descr ]
+
+ # count=self.test_model(
+ # n_epoch, model,
+ # test_primers_descr,
+ # nb_per_primer=1, generate_images=False
+ # )
+
+ # with open(f'perf_{n_epoch:04d}.txt', 'w') as f:
+ # for i in range(count.size(0)):
+ # for j in range(count.size(1)):
+ # f.write(f'{count[i,j]}')
+ # f.write(" " if j<count.size(1)-1 else "\n")
- log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
######################################################################
-class TaskWiki103(Task):
- def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
- device = torch.device('cpu')):
+class TaskWiki103(Task):
+ def __init__(
+ self,
+ batch_size,
+ len_min=10,
+ len_max=200,
+ min_freq=100,
+ device=torch.device("cpu"),
+ ):
self.batch_size = batch_size
        self.len_min = len_min
        self.len_max = len_max
        self.min_freq = min_freq
self.device = device
- self.tokenizer = torchtext.data.get_tokenizer('basic_english')
- train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
+ self.tokenizer = torchtext.data.get_tokenizer("basic_english")
+ train_iter = torchtext.datasets.WikiText103(split="train", root="./data/nlp/")
# Mostly for debug
- if args.data_size > 0:
+ if args.data_size is not None:
train_iter = itertools.islice(train_iter, args.data_size)
def yield_tokens():
- for l in tqdm.tqdm(train_iter, desc = 'vocab'):
+ for l in tqdm.tqdm(train_iter, desc="vocab"):
yield self.tokenizer(l)
self.vocab = torchtext.vocab.build_vocab_from_iterator(
- yield_tokens(),
- specials = [ '<unk>', '<non>' ],
- min_freq = self.min_freq
+ yield_tokens(), specials=["<unk>", "<nul>"], min_freq=self.min_freq
)
- self.vocab.set_default_index(self.vocab[ '<unk>' ])
+ self.vocab.set_default_index(self.vocab["<unk>"])
+    # Make a tensor from a list of lists of tokens
def tensorize(self, s):
a = max(len(x) for x in s)
- return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
+ return torch.tensor([self.vocab(x + ["<nul>"] * (a - len(x))) for x in s])
def yield_batches(self, ds):
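+        # Tokenize the lines, keep those whose length is in
+        # [len_min, len_max], and group them into padded batches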
- s = [ ]
+ s = []
for l in ds:
q = self.tokenizer(l)
if len(q) >= self.len_min and len(q) <= self.len_max:
- s += [ q ]
+ s += [q]
if len(s) == self.batch_size:
yield self.tensorize(s)
- s = [ ]
+ s = []
if len(s) > 0:
yield self.tensorize(s)
- def batches(self, split = 'train'):
- data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
+ def batches(self, split="train"):
+ data_iter = torchtext.datasets.WikiText103(split=split, root="./data/nlp/")
# Mostly for debug
- if args.data_size > 0:
+ if args.data_size is not None:
data_iter = itertools.islice(data_iter, args.data_size)
- return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
+ return self.yield_batches(tqdm.tqdm(data_iter, desc=f"epoch-{split}"))
def vocabulary_size(self):
return len(self.vocab)
- def produce_results(self, n_epoch, model, nb_tokens = 50):
- file_name = f'result_wiki103_{n_epoch:04d}.txt'
-
- with open(file_name, 'w') as outfile:
- for primer in [
- 'the cat is hunting a',
- 'paris is the capital',
- 'cars are convenient',
- 'the difference between men and women is',
- 'the object was blue all over and green all over it was',
- 'cherries are red and lemons are',
- 'cherries are sweet and lemons are',
- 'two plus three equals',
- 'deep learning is',
- ]:
- t_primer = self.tokenizer(primer)
- t_generated = [ ]
-
- for j in range(nb_tokens):
-
- input = self.tensorize([ t_primer + t_generated ]).to(self.device)
- input = F.pad(input, (0, 1)) # Add the next token, the one to predict
- output = model(input)
- logits = output[0, -1]
- if args.synthesis_sampling:
- dist = torch.distributions.categorical.Categorical(logits = logits)
- t_next = dist.sample()
- else:
- t_next = logits.argmax()
- t_generated.append(self.vocab.lookup_token(t_next))
- if t_generated[-1] == '<non>': break
-
- s = ' '.join(t_generated)
-
- outfile.write(f'<{primer}> {s}\n')
-
- log_string(f'wrote {file_name}')
+ def produce_results(self, n_epoch, model):
+ nb_tokens = 50
+ file_name = f"result_wiki103_{n_epoch:04d}.txt"
+
+ with open(file_name, "w") as outfile:
+ for primer in [
+ "the cat is hunting a",
+ "paris is the capital",
+ "cars are convenient",
+ "the difference between men and women is",
+ "the object was blue all over and green all over it was",
+ "cherries are red and lemons are",
+ "cherries are sweet and lemons are",
+ "two plus three equals",
+ "deep learning is",
+ ]:
+ t_primer = self.tokenizer(primer)
+ t_generated = []
+
+ for j in range(nb_tokens):
+
+ input = self.tensorize([t_primer + t_generated]).to(self.device)
+ input = F.pad(
+ input, (0, 1)
+ ) # Add the next token, the one to predict
+ output = model(input)
+ logits = output[0, -1]
+ if args.deterministic_synthesis:
+ t_next = logits.argmax()
+ else:
+ dist = torch.distributions.categorical.Categorical(
+ logits=logits
+ )
+ t_next = dist.sample()
+ t_generated.append(self.vocab.lookup_token(t_next))
+ if t_generated[-1] == "<nul>":
+ break
+
+ s = " ".join(t_generated)
+
+ outfile.write(f"<{primer}> {s}\n")
+
+ log_string(f"wrote {file_name}")
+
######################################################################
-class TaskMNIST(Task):
- def __init__(self, batch_size, device = torch.device('cpu')):
+class TaskMNIST(Task):
+ def __init__(self, batch_size, device=torch.device("cpu")):
self.device = device
self.batch_size = batch_size
- def batches(self, split = 'train'):
- assert split in { 'train', 'test' }
+ def batches(self, split="train"):
+ assert split in {"train", "test"}
data_set = torchvision.datasets.MNIST(
- root = './data', train = (split == 'train'),
- download = True
+ root="./data", train=(split == "train"), download=True
)
data_input = data_set.data.view(-1, 28 * 28).long()
- if args.data_size >= 0:
- data_input = data_input[:args.data_size]
- for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
+ if args.data_size is not None:
+ data_input = data_input[: args.data_size]
+ for batch in tqdm.tqdm(
+ data_input.split(self.batch_size), desc=f"epoch-{split}"
+ ):
yield batch
def vocabulary_size(self):
return 256
- def produce_results(self, n_epoch, model, nb_samples = 64):
- results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
- image_name = f'result_mnist_{n_epoch:04d}.png'
- torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
- image_name, nrow = 16, pad_value = 0.8)
- log_string(f'wrote {image_name}')
+ def produce_results(self, n_epoch, model):
+ nb_samples = 64
+ results = autoregression(
+ model, self.batch_size, nb_samples, 28 * 28, device=self.device
+ )
+ image_name = f"result_mnist_{n_epoch:04d}.png"
+ torchvision.utils.save_image(
+ 1 - results.reshape(-1, 1, 28, 28) / 255.0,
+ image_name,
+ nrow=16,
+ pad_value=0.8,
+ )
+ log_string(f"wrote {image_name}")
+
######################################################################
-log_string(f'device {device}')
-
-if args.data == 'wiki103':
- nb_epochs_default = 10
- task = TaskWiki103(batch_size = args.batch_size, device = device)
-elif args.data == 'mnist':
- nb_epochs_default = 25
- task = TaskMNIST(batch_size = args.batch_size, device = device)
-elif args.data == 'picoclvr':
- nb_epochs_default = 10
- task = TaskPicoCLVR(batch_size = args.batch_size,
- height = args.picoclvr_height,
- width = args.picoclvr_width,
- nb_colors = args.picoclvr_nb_colors,
- device = device)
+log_string(f"device {device}")
+
+if args.data == "wiki103":
+ task = TaskWiki103(batch_size=args.batch_size, device=device)
+elif args.data in {"mnist", "mnist-debug"}:
+ task = TaskMNIST(batch_size=args.batch_size, device=device)
+elif args.data == "picoclvr":
+ task = TaskPicoCLVR(
+ batch_size=args.batch_size,
+ height=args.picoclvr_height,
+ width=args.picoclvr_width,
+ nb_colors=args.picoclvr_nb_colors,
+ device=device,
+ )
else:
- raise ValueError(f'Unknown dataset {args.data}.')
+ raise ValueError(f"Unknown dataset {args.data}.")
vocabulary_size = task.vocabulary_size()
-log_string(f'vocabulary_size {vocabulary_size}')
+log_string(f"vocabulary_size {vocabulary_size}")
##############################
model = mygpt.MyGPT(
- vocabulary_size = vocabulary_size,
- dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
- nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
+ vocabulary_size=vocabulary_size,
+ dim_model=args.dim_model,
+ dim_keys=args.dim_keys,
+ dim_hidden=args.dim_hidden,
+ nb_heads=args.nb_heads,
+ nb_blocks=args.nb_blocks,
+ dropout=args.dropout,
)
model.to(device)
nb_parameters = sum(p.numel() for p in model.parameters())
-log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
-
-######################################################################
-
-if args.optim == 'sgd':
- optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adam':
- optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adamw':
- optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
-else:
- raise ValueError(f'Unknown optimizer {args.optim}.')
+log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
######################################################################
nb_epochs_finished = 0
if args.no_checkpoint:
- log_string(f'not trying to load checkpoint.')
+ log_string(f"not trying to load checkpoint.")
else:
try:
- checkpoint = torch.load(args.checkpoint_name, map_location = device)
- nb_epochs_finished = checkpoint['nb_epochs_finished']
- model.load_state_dict(checkpoint['model_state'])
- optimizer.load_state_dict(checkpoint['optimizer_state'])
- log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
+ checkpoint = torch.load(args.checkpoint_name)
+ nb_epochs_finished = checkpoint["nb_epochs_finished"]
+ model.load_state_dict(checkpoint["model_state"])
+ torch.set_rng_state(checkpoint["rng_state"])
+ if torch.cuda.is_available():
+ torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])
+ log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")
except FileNotFoundError:
- log_string('starting from scratch.')
+ log_string("starting from scratch.")
except:
- log_string('error when loading the checkpoint.')
+ log_string("error when loading the checkpoint.")
exit(1)
######################################################################
-nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
-
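+# Estimate the entropy of the train set's unigram token distribution;
+# its exponential is a perplexity baseline for the model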
token_count = 0
-for input in task.batches(split = 'train'):
- token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
+for input in task.batches(split="train"):
+ token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
-#log_string(f'train set perplexity {train_set_perplexity}')
-for k in range(nb_epochs_finished, nb_epochs):
+for n_epoch in range(nb_epochs_finished, args.nb_epochs):
+
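+    # Log-linear (geometric) interpolation of the learning rate from
+    # learning_rate down to learning_rate_end over the nb_epochs epochs;
+    # a negative --learning_rate_end disables the schedule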
+ if args.learning_rate_end < 0:
+ lr = args.learning_rate
+ else:
+ u = n_epoch / (args.nb_epochs - 1)
+ lr = math.exp(
+ (1 - u) * math.log(args.learning_rate)
+ + u * math.log(args.learning_rate_end)
+ )
+ log_string(f"learning_rate {lr}")
+
+ if args.optim == "sgd":
+ optimizer = torch.optim.SGD(model.parameters(), lr=lr)
+ elif args.optim == "adam":
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+ elif args.optim == "adamw":
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+ else:
+ raise ValueError(f"Unknown optimizer {args.optim}.")
model.train()
nb_train_samples, acc_train_loss = 0, 0.0
- for input in task.batches(split = 'train'):
+ for input in task.batches(split="train"):
input = input.to(device)
output = model(input)
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    model.eval()
    nb_test_samples, acc_test_loss = 0, 0.0
- for input in task.batches(split = 'test'):
+ for input in task.batches(split="test"):
input = input.to(device)
output = model(input)
loss = F.cross_entropy(output.transpose(1, 2), input)
acc_test_loss += loss.item() * input.size(0)
nb_test_samples += input.size(0)
- train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
- test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
+ train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+ test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
- log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
+ log_string(
+ f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
+ )
- task.produce_results(k, model)
+ task.produce_results(n_epoch, model)
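+
+    # Save the RNG states along with the model so that a run resumed from
+    # this checkpoint continues with the same random draws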
checkpoint = {
- 'nb_epochs_finished': k + 1,
- 'model_state': model.state_dict(),
- 'optimizer_state': optimizer.state_dict()
+ "nb_epochs_finished": n_epoch + 1,
+ "model_state": model.state_dict(),
+ "rng_state": torch.get_rng_state(),
}
+ if torch.cuda.is_available():
+ checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()
+
torch.save(checkpoint, args.checkpoint_name)
######################################################################