From 199f3195388af8be1f3e50dec343964f73fc0e6d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 21 Jan 2023 16:32:53 +0100 Subject: [PATCH] Added default configurations and reformated with black. --- main.py | 530 ++++++++++++++++++++++++++++------------------------ mygpt.py | 99 ++++++---- picoclvr.py | 348 ++++++++++++++++++++++------------ 3 files changed, 579 insertions(+), 398 deletions(-) diff --git a/main.py b/main.py index aa1b517..f7d03cf 100755 --- a/main.py +++ b/main.py @@ -15,111 +15,138 @@ import mygpt ###################################################################### -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ###################################################################### -parser = argparse.ArgumentParser(description = 'My own GPT.') +parser = argparse.ArgumentParser(description="My own GPT.") -parser.add_argument('--log_filename', - type = str, default = 'train.log') +parser.add_argument("--log_filename", type=str, default="train.log") -parser.add_argument('--seed', - type = int, default = 0) +parser.add_argument("--seed", type=int, default=0) -parser.add_argument('--nb_epochs', - type = int, default = -1) +parser.add_argument("--nb_epochs", type=int, default=None) -parser.add_argument('--batch_size', - type = int, default = 25) +parser.add_argument("--batch_size", type=int, default=25) -parser.add_argument('--data', - type = str, default = 'wiki103') +parser.add_argument("--data", type=str, default="wiki103") -parser.add_argument('--data_size', - type = int, default = -1) +parser.add_argument("--data_size", type=int, default=None) -parser.add_argument('--optim', - type = str, default = 'adam') +parser.add_argument("--optim", type=str, default="adam") -parser.add_argument('--learning_rate', - type = float, default = 1e-3) +parser.add_argument("--learning_rate", type=float, default=1e-3) -parser.add_argument('--learning_rate_end', - type = float, default = 1e-6) +parser.add_argument("--learning_rate_end", type=float, default=1e-6) -parser.add_argument('--dim_model', - type = int, default = 512) +parser.add_argument("--dim_model", type=int, default=None) -parser.add_argument('--dim_keys', - type = int, default = 64) +parser.add_argument("--dim_keys", type=int, default=None) -parser.add_argument('--dim_hidden', - type = int, default = 2048) +parser.add_argument("--dim_hidden", type=int, default=None) -parser.add_argument('--nb_heads', - type = int, default = 8) +parser.add_argument("--nb_heads", type=int, default=None) -parser.add_argument('--nb_blocks', - type = int, default = 12) +parser.add_argument("--nb_blocks", type=int, default=None) -parser.add_argument('--dropout', - type = float, default = 0.1) +parser.add_argument("--dropout", type=float, default=0.1) -parser.add_argument('--deterministic_synthesis', - action='store_true', default = False) +parser.add_argument("--deterministic_synthesis", action="store_true", default=False) -parser.add_argument('--no_checkpoint', - action='store_true', default = False) +parser.add_argument("--no_checkpoint", action="store_true", default=False) -parser.add_argument('--checkpoint_name', - type = str, default = 'checkpoint.pth') +parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") ############################## # picoclvr options -parser.add_argument('--picoclvr_nb_colors', - type = int, default = 5) +parser.add_argument("--picoclvr_nb_colors", type=int, default=5) -parser.add_argument('--picoclvr_height', - type = int, default = 12) +parser.add_argument("--picoclvr_height", type=int, default=12) -parser.add_argument('--picoclvr_width', - type = int, default = 16) +parser.add_argument("--picoclvr_width", type=int, default=16) ###################################################################### args = parser.parse_args() -log_file = open(args.log_filename, 'w') +log_file = open(args.log_filename, "w") if args.seed >= 0: torch.manual_seed(args.seed) ###################################################################### + def log_string(s): - t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime()) + t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) if log_file is not None: - log_file.write(t + s + '\n') + log_file.write(t + s + "\n") log_file.flush() print(t + s) sys.stdout.flush() + for n in vars(args): - log_string(f'args.{n} {getattr(args, n)}') + log_string(f"args.{n} {getattr(args, n)}") ###################################################################### +default_args = { + "mnist": { + "nb_epochs": 10, + "dim_model": 64, + "dim_keys": 64, + "dim_hidden": 128, + "nb_heads": 4, + "nb_blocks": 6, + }, + "mnist-debug": { + "nb_epochs": 2, + "data_size": 10000, + "dim_model": 8, + "dim_keys": 8, + "dim_hidden": 8, + "nb_heads": 2, + "nb_blocks": 4, + }, + "wiki103": { + "nb_epochs": 25, + "dim_model": 512, + "dim_keys": 64, + "dim_hidden": 2048, + "nb_heads": 8, + "nb_blocks": 12, + }, + "picoclvr": { + "nb_epochs": 25, + "dim_model": 512, + "dim_keys": 64, + "dim_hidden": 2048, + "nb_heads": 8, + "nb_blocks": 12, + }, +} + +if args.data in default_args: + for k, v in default_args[args.data].items(): + if getattr(args, k) is None: + setattr(args, k, v) + +###################################################################### + + def autoregression( - model, batch_size, - nb_samples, nb_tokens_to_generate, primer = None, - device = torch.device('cpu') + model, + batch_size, + nb_samples, + nb_tokens_to_generate, + primer=None, + device=torch.device("cpu"), ): results = torch.zeros( - nb_samples, nb_tokens_to_generate, - dtype = torch.int64, device = device + nb_samples, nb_tokens_to_generate, dtype=torch.int64, device=device ) if primer is None: @@ -135,16 +162,18 @@ def autoregression( if args.deterministic_synthesis: t_next = logits.argmax(1) else: - dist = torch.distributions.categorical.Categorical(logits = logits) + dist = torch.distributions.categorical.Categorical(logits=logits) t_next = dist.sample() input[:, s] = t_next return results + ###################################################################### + class Task: - def batches(self, split = 'train'): + def batches(self, split="train"): pass def vocabulary_size(self): @@ -153,130 +182,127 @@ class Task: def produce_results(self, n_epoch, model): pass + ###################################################################### import picoclvr + class TaskPicoCLVR(Task): # Make a tensor from a list of strings def tensorize(self, descr): - token_descr = [ s.strip().split(' ') for s in descr ] - l = max([ len(s) for s in token_descr ]) - #token_descr = [ [ '' ] * (l - len(s)) + s for s in token_descr ] - token_descr = [ s + [ '' ] * (l - len(s)) for s in token_descr ] - id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ] - return torch.tensor(id_descr, device = self.device) - - def trim(self, x, token = ''): + token_descr = [s.strip().split(" ") for s in descr] + l = max([len(s) for s in token_descr]) + padded_token_descr = [s + [""] * (l - len(s)) for s in token_descr] + id_descr = [[self.token2id[u] for u in s] for s in padded_token_descr] + return torch.tensor(id_descr, device=self.device) + + def trim(self, x, token=""): n = self.token2id[token] - i = (1 - (F.pad(x, (1, 1), value = n) == n).min(0).values.long()).cumsum(0) + i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0) a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min() return x[:, a:b] - def __init__(self, batch_size, - height, width, nb_colors = 5, - device = torch.device('cpu')): - + def __init__( + self, batch_size, height, width, nb_colors=5, device=torch.device("cpu") + ): def generate_descr(nb): return picoclvr.generate( - nb, - height = self.height, width = self.width, - nb_colors = nb_colors + nb, height=self.height, width=self.width, nb_colors=nb_colors ) self.height = height self.width = width self.batch_size = batch_size self.device = device - nb = args.data_size if args.data_size > 0 else 250000 + nb = args.data_size if args.data_size is not None else 250000 - log_string(f'generating {nb} samples (can take some time)') + log_string(f"generating {nb} samples (can take some time)") self.train_descr = generate_descr((nb * 4) // 5) self.test_descr = generate_descr((nb * 1) // 5) # Build the tokenizer - tokens = { '' } - for d in [ self.train_descr, self.test_descr ]: + tokens = {""} + for d in [self.train_descr, self.test_descr]: for s in d: - for t in s.strip().split(' '): tokens.add(t) - self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ]) - self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ]) + for t in s.strip().split(" "): + tokens.add(t) + self.token2id = dict([(t, n) for n, t in enumerate(tokens)]) + self.id2token = dict([(n, t) for n, t in enumerate(tokens)]) # Tokenize the train and test sets self.train_input = self.tensorize(self.train_descr) self.test_input = self.tensorize(self.test_descr) - def batches(self, split = 'train'): - assert split in { 'train', 'test' } - input = self.train_input if split == 'train' else self.test_input - for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'): + def batches(self, split="train"): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + for batch in tqdm.tqdm(input.split(self.batch_size), desc=f"epoch-{split}"): yield self.trim(batch) def vocabulary_size(self): return len(self.token2id) - def test_model(self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False): + def test_model( + self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False + ): nb_tokens_to_generate = self.height * self.width + 3 - result_descr = [ ] + result_descr = [] for primer_descr in primers_descr: results = autoregression( model, self.batch_size, - nb_samples = nb_per_primer, - nb_tokens_to_generate = nb_tokens_to_generate, - primer = self.tensorize([ primer_descr ]).expand(nb_per_primer, -1), - device = self.device + nb_samples=nb_per_primer, + nb_tokens_to_generate=nb_tokens_to_generate, + primer=self.tensorize([primer_descr]).expand(nb_per_primer, -1), + device=self.device, ) - l = [ ' '.join([ self.id2token[t.item()] for t in r ]) for r in results ] + l = [" ".join([self.id2token[t.item()] for t in r]) for r in results] result_descr += l - np = picoclvr.nb_properties( - result_descr, - height = self.height, width = self.width - ) + np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width) nb_requested_properties, _, nb_missing_properties = zip(*np) - log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}') + log_string( + f"nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}" + ) - np=torch.tensor(np) - count=torch.empty(np[:,0].max()+1,np[:,2].max()+1,dtype=torch.int64) + np = torch.tensor(np) + count = torch.empty(np[:, 0].max() + 1, np[:, 2].max() + 1, dtype=torch.int64) for i in range(count.size(0)): for j in range(count.size(1)): - count[i,j]=((np[:,0]==i).long()*(np[:,2]==j).long()).sum() + count[i, j] = ((np[:, 0] == i).long() * (np[:, 2] == j).long()).sum() if generate_images: img = [ - picoclvr.descr2img(d, height = self.height, width = self.width) + picoclvr.descr2img(d, height=self.height, width=self.width) for d in result_descr ] img = torch.cat(img, 0) - image_name = f'result_picoclvr_{n_epoch:04d}.png' + image_name = f"result_picoclvr_{n_epoch:04d}.png" torchvision.utils.save_image( - img / 255., - image_name, nrow = nb_per_primer, pad_value = 0.8 + img / 255.0, image_name, nrow=nb_per_primer, pad_value=0.8 ) - log_string(f'wrote {image_name}') + log_string(f"wrote {image_name}") return count def produce_results(self, n_epoch, model): primers_descr = [ - 'red above green green top blue right of red ', - 'there is red there is yellow there is blue ', - 'red below yellow yellow below green green below blue red right yellow left green right blue left ', - 'green bottom yellow bottom green left of blue yellow right of blue blue top ', + "red above green green top blue right of red ", + "there is red there is yellow there is blue ", + "red below yellow yellow below green green below blue red right yellow left green right blue left ", + "green bottom yellow bottom green left of blue yellow right of blue blue top ", ] self.test_model( - n_epoch, model, - primers_descr, - nb_per_primer=8, generate_images=True + n_epoch, model, primers_descr, nb_per_primer=8, generate_images=True ) # FAR TOO SLOW!!! @@ -284,23 +310,30 @@ class TaskPicoCLVR(Task): # test_primers_descr=[ s.split('')[0] for s in self.test_descr ] # count=self.test_model( - # n_epoch, model, - # test_primers_descr, - # nb_per_primer=1, generate_images=False + # n_epoch, model, + # test_primers_descr, + # nb_per_primer=1, generate_images=False # ) # with open(f'perf_{n_epoch:04d}.txt', 'w') as f: - # for i in range(count.size(0)): - # for j in range(count.size(1)): - # f.write(f'{count[i,j]}') - # f.write(" " if j 0: + if args.data_size is not None: train_iter = itertools.islice(train_iter, args.data_size) def yield_tokens(): - for l in tqdm.tqdm(train_iter, desc = 'vocab'): + for l in tqdm.tqdm(train_iter, desc="vocab"): yield self.tokenizer(l) self.vocab = torchtext.vocab.build_vocab_from_iterator( - yield_tokens(), - specials = [ '', '' ], - min_freq = self.min_freq + yield_tokens(), specials=["", ""], min_freq=self.min_freq ) - self.vocab.set_default_index(self.vocab[ '' ]) + self.vocab.set_default_index(self.vocab[""]) # makes a tensor from a list of list of tokens def tensorize(self, s): a = max(len(x) for x in s) - return torch.tensor([ self.vocab(x + [ '' ] * (a - len(x))) for x in s ]) + return torch.tensor([self.vocab(x + [""] * (a - len(x))) for x in s]) def yield_batches(self, ds): - s = [ ] + s = [] for l in ds: q = self.tokenizer(l) if len(q) >= self.len_min and len(q) <= self.len_max: - s += [ q ] + s += [q] if len(s) == self.batch_size: yield self.tensorize(s) - s = [ ] + s = [] if len(s) > 0: yield self.tensorize(s) - def batches(self, split = 'train'): - data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/') + def batches(self, split="train"): + data_iter = torchtext.datasets.WikiText103(split=split, root="./data/nlp/") # Mostly for debug - if args.data_size > 0: + if args.data_size is not None: data_iter = itertools.islice(data_iter, args.data_size) - return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}')) + return self.yield_batches(tqdm.tqdm(data_iter, desc=f"epoch-{split}")) def vocabulary_size(self): return len(self.vocab) def produce_results(self, n_epoch, model): nb_tokens = 50 - file_name = f'result_wiki103_{n_epoch:04d}.txt' - - with open(file_name, 'w') as outfile: - for primer in [ - 'the cat is hunting a', - 'paris is the capital', - 'cars are convenient', - 'the difference between men and women is', - 'the object was blue all over and green all over it was', - 'cherries are red and lemons are', - 'cherries are sweet and lemons are', - 'two plus three equals', - 'deep learning is', - ]: - t_primer = self.tokenizer(primer) - t_generated = [ ] - - for j in range(nb_tokens): - - input = self.tensorize([ t_primer + t_generated ]).to(self.device) - input = F.pad(input, (0, 1)) # Add the next token, the one to predict - output = model(input) - logits = output[0, -1] - if args.deterministic_synthesis: - t_next = logits.argmax() - else: - dist = torch.distributions.categorical.Categorical(logits = logits) - t_next = dist.sample() - t_generated.append(self.vocab.lookup_token(t_next)) - if t_generated[-1] == '': break - - s = ' '.join(t_generated) - - outfile.write(f'<{primer}> {s}\n') - - log_string(f'wrote {file_name}') + file_name = f"result_wiki103_{n_epoch:04d}.txt" + + with open(file_name, "w") as outfile: + for primer in [ + "the cat is hunting a", + "paris is the capital", + "cars are convenient", + "the difference between men and women is", + "the object was blue all over and green all over it was", + "cherries are red and lemons are", + "cherries are sweet and lemons are", + "two plus three equals", + "deep learning is", + ]: + t_primer = self.tokenizer(primer) + t_generated = [] + + for j in range(nb_tokens): + + input = self.tensorize([t_primer + t_generated]).to(self.device) + input = F.pad( + input, (0, 1) + ) # Add the next token, the one to predict + output = model(input) + logits = output[0, -1] + if args.deterministic_synthesis: + t_next = logits.argmax() + else: + dist = torch.distributions.categorical.Categorical( + logits=logits + ) + t_next = dist.sample() + t_generated.append(self.vocab.lookup_token(t_next)) + if t_generated[-1] == "": + break + + s = " ".join(t_generated) + + outfile.write(f"<{primer}> {s}\n") + + log_string(f"wrote {file_name}") + ###################################################################### -class TaskMNIST(Task): - def __init__(self, batch_size, device = torch.device('cpu')): +class TaskMNIST(Task): + def __init__(self, batch_size, device=torch.device("cpu")): self.device = device self.batch_size = batch_size - def batches(self, split = 'train'): - assert split in { 'train', 'test' } + def batches(self, split="train"): + assert split in {"train", "test"} data_set = torchvision.datasets.MNIST( - root = './data', train = (split == 'train'), - download = True + root="./data", train=(split == "train"), download=True ) data_input = data_set.data.view(-1, 28 * 28).long() - if args.data_size >= 0: - data_input = data_input[:args.data_size] - for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'): + if args.data_size is not None: + data_input = data_input[: args.data_size] + for batch in tqdm.tqdm( + data_input.split(self.batch_size), desc=f"epoch-{split}" + ): yield batch def vocabulary_size(self): @@ -421,108 +459,118 @@ class TaskMNIST(Task): def produce_results(self, n_epoch, model): nb_samples = 64 - results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device) - image_name = f'result_mnist_{n_epoch:04d}.png' - torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255., - image_name, nrow = 16, pad_value = 0.8) - log_string(f'wrote {image_name}') + results = autoregression( + model, self.batch_size, nb_samples, 28 * 28, device=self.device + ) + image_name = f"result_mnist_{n_epoch:04d}.png" + torchvision.utils.save_image( + 1 - results.reshape(-1, 1, 28, 28) / 255.0, + image_name, + nrow=16, + pad_value=0.8, + ) + log_string(f"wrote {image_name}") + ###################################################################### -log_string(f'device {device}') - -if args.data == 'wiki103': - nb_epochs_default = 10 - task = TaskWiki103(batch_size = args.batch_size, device = device) -elif args.data == 'mnist': - nb_epochs_default = 25 - task = TaskMNIST(batch_size = args.batch_size, device = device) -elif args.data == 'picoclvr': - nb_epochs_default = 10 - task = TaskPicoCLVR(batch_size = args.batch_size, - height = args.picoclvr_height, - width = args.picoclvr_width, - nb_colors = args.picoclvr_nb_colors, - device = device) +log_string(f"device {device}") + +if args.data == "wiki103": + task = TaskWiki103(batch_size=args.batch_size, device=device) +elif args.data in {"mnist", "mnist-debug"}: + task = TaskMNIST(batch_size=args.batch_size, device=device) +elif args.data == "picoclvr": + task = TaskPicoCLVR( + batch_size=args.batch_size, + height=args.picoclvr_height, + width=args.picoclvr_width, + nb_colors=args.picoclvr_nb_colors, + device=device, + ) else: - raise ValueError(f'Unknown dataset {args.data}.') + raise ValueError(f"Unknown dataset {args.data}.") vocabulary_size = task.vocabulary_size() -log_string(f'vocabulary_size {vocabulary_size}') +log_string(f"vocabulary_size {vocabulary_size}") ############################## model = mygpt.MyGPT( - vocabulary_size = vocabulary_size, - dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden, - nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout + vocabulary_size=vocabulary_size, + dim_model=args.dim_model, + dim_keys=args.dim_keys, + dim_hidden=args.dim_hidden, + nb_heads=args.nb_heads, + nb_blocks=args.nb_blocks, + dropout=args.dropout, ) model.to(device) nb_parameters = sum(p.numel() for p in model.parameters()) -log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)') +log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") ###################################################################### nb_epochs_finished = 0 if args.no_checkpoint: - log_string(f'not trying to load checkpoint.') + log_string(f"not trying to load checkpoint.") else: try: checkpoint = torch.load(args.checkpoint_name) - nb_epochs_finished = checkpoint['nb_epochs_finished'] - model.load_state_dict(checkpoint['model_state']) - torch.set_rng_state(checkpoint['rng_state']) + nb_epochs_finished = checkpoint["nb_epochs_finished"] + model.load_state_dict(checkpoint["model_state"]) + torch.set_rng_state(checkpoint["rng_state"]) if torch.cuda.is_available(): - torch.cuda.set_rng_state(checkpoint['cuda_rng_state']) - log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.') + torch.cuda.set_rng_state(checkpoint["cuda_rng_state"]) + log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.") except FileNotFoundError: - log_string('starting from scratch.') + log_string("starting from scratch.") except: - log_string('error when loading the checkpoint.') + log_string("error when loading the checkpoint.") exit(1) ###################################################################### -nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default - token_count = 0 -for input in task.batches(split = 'train'): - token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1)) +for input in task.batches(split="train"): + token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1)) token_probas = token_count / token_count.sum() entropy = -torch.xlogy(token_probas, token_probas).sum() train_set_perplexity = math.exp(entropy) -for n_epoch in range(nb_epochs_finished, nb_epochs): +for n_epoch in range(nb_epochs_finished, args.nb_epochs): if args.learning_rate_end < 0: lr = args.learning_rate else: - u = n_epoch / (nb_epochs - 1) - lr = math.exp((1 - u) * math.log(args.learning_rate) + - u * math.log(args.learning_rate_end)) - log_string(f'learning_rate {lr}') - - if args.optim == 'sgd': - optimizer = torch.optim.SGD(model.parameters(), lr = lr) - elif args.optim == 'adam': - optimizer = torch.optim.Adam(model.parameters(), lr = lr) - elif args.optim == 'adamw': - optimizer = torch.optim.AdamW(model.parameters(), lr = lr) + u = n_epoch / (args.nb_epochs - 1) + lr = math.exp( + (1 - u) * math.log(args.learning_rate) + + u * math.log(args.learning_rate_end) + ) + log_string(f"learning_rate {lr}") + + if args.optim == "sgd": + optimizer = torch.optim.SGD(model.parameters(), lr=lr) + elif args.optim == "adam": + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + elif args.optim == "adamw": + optimizer = torch.optim.AdamW(model.parameters(), lr=lr) else: - raise ValueError(f'Unknown optimizer {args.optim}.') + raise ValueError(f"Unknown optimizer {args.optim}.") model.train() nb_train_samples, acc_train_loss = 0, 0.0 - for input in task.batches(split = 'train'): + for input in task.batches(split="train"): input = input.to(device) output = model(input) loss = F.cross_entropy(output.transpose(1, 2), input) @@ -539,28 +587,30 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): nb_test_samples, acc_test_loss = 0, 0.0 - for input in task.batches(split = 'test'): + for input in task.batches(split="test"): input = input.to(device) output = model(input) loss = F.cross_entropy(output.transpose(1, 2), input) acc_test_loss += loss.item() * input.size(0) nb_test_samples += input.size(0) - train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples)) - test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples)) + train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) + test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) - log_string(f'perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}') + log_string( + f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}" + ) task.produce_results(n_epoch, model) checkpoint = { - 'nb_epochs_finished': n_epoch + 1, - 'model_state': model.state_dict(), - 'rng_state': torch.get_rng_state(), + "nb_epochs_finished": n_epoch + 1, + "model_state": model.state_dict(), + "rng_state": torch.get_rng_state(), } if torch.cuda.is_available(): - checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state() + checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state() torch.save(checkpoint, args.checkpoint_name) diff --git a/mygpt.py b/mygpt.py index f954797..a6b257c 100755 --- a/mygpt.py +++ b/mygpt.py @@ -14,6 +14,7 @@ from torch.nn import functional as F ############################## + class WithResidual(nn.Module): def __init__(self, *f): super().__init__() @@ -22,8 +23,10 @@ class WithResidual(nn.Module): def forward(self, x): return x + self.f(x) + ############################## + class AddPositionalEncoding(nn.Module): def __init__(self, len_max): super().__init__() @@ -31,18 +34,20 @@ class AddPositionalEncoding(nn.Module): # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D})) def forward(self, x): - t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None] - j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :] - k = j%2 - pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k) + t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None] + j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :] + k = j % 2 + pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k) return x + pe + ############################## + class QKVAttention(nn.Module): - def __init__(self, - dim_in, dim_qk, dim_v, - nb_heads = 1, causal = False, attention_dropout = 0.0): + def __init__( + self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0 + ): super().__init__() def randw(*d): @@ -56,36 +61,47 @@ class QKVAttention(nn.Module): self.w_v = randw(nb_heads, dim_v, dim_in) self.w_o = randw(dim_v * nb_heads, dim_in) - def forward(self, x_q, x_kv = None): - if x_kv is None: x_kv = x_q + def forward(self, x_q, x_kv=None): + if x_kv is None: + x_kv = x_q - q = torch.einsum('ntc,hdc->nhtd', x_q, self.w_q) - k = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_k) - v = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_v) + q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q) + k = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_k) + v = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_v) - a = torch.einsum('nhtd,nhsd->nhts', q, k) / math.sqrt(q.size(3)) + a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3)) if self.causal: - forbidden_attention = torch.arange(a.size(2), device = q.device)[None, None, :, None] \ - < torch.arange(a.size(3), device = q.device)[None, None, None, :] - a = a.masked_fill(forbidden_attention, float('-inf')) + forbidden_attention = ( + torch.arange(a.size(2), device=q.device)[None, None, :, None] + < torch.arange(a.size(3), device=q.device)[None, None, None, :] + ) + a = a.masked_fill(forbidden_attention, float("-inf")) - a = a.softmax(dim = 3) + a = a.softmax(dim=3) a = F.dropout(a, self.attention_dropout, self.training) - y = torch.einsum('nhts,nhsd->nthd', a, v).flatten(2) + y = torch.einsum("nhts,nhsd->nthd", a, v).flatten(2) y = y @ self.w_o return y + ############################## + class MyGPT(nn.Module): - def __init__(self, - vocabulary_size, - dim_model, dim_keys, dim_hidden, - nb_heads, nb_blocks, - dropout = 0.0, len_max = 1e5): + def __init__( + self, + vocabulary_size, + dim_model, + dim_keys, + dim_hidden, + nb_heads, + nb_blocks, + dropout=0.0, + len_max=1e5, + ): super().__init__() @@ -97,37 +113,38 @@ class MyGPT(nn.Module): AddPositionalEncoding(len_max), ) - trunk_blocks = [ ] + trunk_blocks = [] for _ in range(nb_blocks): trunk_blocks += [ WithResidual( nn.LayerNorm((dim_model,)), QKVAttention( - dim_in = dim_model, - dim_qk = dim_keys, - dim_v = dim_model // nb_heads, - nb_heads = nb_heads, - causal = True, attention_dropout = dropout + dim_in=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + causal=True, + attention_dropout=dropout, ), ), WithResidual( nn.LayerNorm((dim_model,)), - nn.Linear(in_features = dim_model, out_features = dim_hidden), + nn.Linear(in_features=dim_model, out_features=dim_hidden), nn.ReLU(), - nn.Linear(in_features = dim_hidden, out_features = dim_model), + nn.Linear(in_features=dim_hidden, out_features=dim_model), nn.Dropout(dropout), ), ] self.trunk = nn.Sequential(*trunk_blocks) - self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size) + self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size) with torch.no_grad(): for m in self.modules(): if isinstance(m, nn.Embedding): - m.weight.normal_(mean = 0, std = 2e-2) + m.weight.normal_(mean=0, std=2e-2) elif isinstance(m, nn.LayerNorm): m.bias.zero_() m.weight.fill_(1.0) @@ -139,19 +156,23 @@ class MyGPT(nn.Module): x = self.readout(x) return x + ###################################################################### -if __name__ == '__main__': - print('Basic check.') +if __name__ == "__main__": + print("Basic check.") vocabulary_size = 10 x = torch.randint(vocabulary_size, (25, 100)) model = MyGPT( - vocabulary_size = vocabulary_size, - dim_model = 18, dim_keys = 50, dim_hidden = 100, - nb_heads = 2, nb_blocks = 3, - dropout = 0.1 + vocabulary_size=vocabulary_size, + dim_model=18, + dim_keys=50, + dim_hidden=100, + nb_heads=2, + nb_blocks=3, + dropout=0.1, ) y = model(x) diff --git a/picoclvr.py b/picoclvr.py index 059e352..fb791fe 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -7,100 +7,194 @@ import torch, torchvision -colors = [ - [ 255, 255, 255 ], [ 255, 0, 0 ], [ 0, 128, 0 ], [ 0, 0, 255 ], [ 255, 255, 0 ], - [ 0, 0, 0 ], [ 128, 0, 0 ], [ 139, 0, 0 ], [ 165, 42, 42 ], [ 178, 34, 34 ], - [ 220, 20, 60 ], [ 255, 99, 71 ], [ 255, 127, 80 ], [ 205, 92, 92 ], [ 240, 128, 128 ], - [ 233, 150, 122 ], [ 250, 128, 114 ], [ 255, 160, 122 ], [ 255, 69, 0 ], [ 255, 140, 0 ], - [ 255, 165, 0 ], [ 255, 215, 0 ], [ 184, 134, 11 ], [ 218, 165, 32 ], [ 238, 232, 170 ], - [ 189, 183, 107 ], [ 240, 230, 140 ], [ 128, 128, 0 ], [ 154, 205, 50 ], [ 85, 107, 47 ], - [ 107, 142, 35 ], [ 124, 252, 0 ], [ 127, 255, 0 ], [ 173, 255, 47 ], [ 0, 100, 0 ], - [ 34, 139, 34 ], [ 0, 255, 0 ], [ 50, 205, 50 ], [ 144, 238, 144 ], [ 152, 251, 152 ], - [ 143, 188, 143 ], [ 0, 250, 154 ], [ 0, 255, 127 ], [ 46, 139, 87 ], [ 102, 205, 170 ], - [ 60, 179, 113 ], [ 32, 178, 170 ], [ 47, 79, 79 ], [ 0, 128, 128 ], [ 0, 139, 139 ], - [ 0, 255, 255 ], [ 0, 255, 255 ], [ 224, 255, 255 ], [ 0, 206, 209 ], [ 64, 224, 208 ], - [ 72, 209, 204 ], [ 175, 238, 238 ], [ 127, 255, 212 ], [ 176, 224, 230 ], [ 95, 158, 160 ], - [ 70, 130, 180 ], [ 100, 149, 237 ], [ 0, 191, 255 ], [ 30, 144, 255 ], [ 173, 216, 230 ], - [ 135, 206, 235 ], [ 135, 206, 250 ], [ 25, 25, 112 ], [ 0, 0, 128 ], [ 0, 0, 139 ], - [ 0, 0, 205 ], [ 65, 105, 225 ], [ 138, 43, 226 ], [ 75, 0, 130 ], [ 72, 61, 139 ], - [ 106, 90, 205 ], [ 123, 104, 238 ], [ 147, 112, 219 ], [ 139, 0, 139 ], [ 148, 0, 211 ], - [ 153, 50, 204 ], [ 186, 85, 211 ], [ 128, 0, 128 ], [ 216, 191, 216 ], [ 221, 160, 221 ], - [ 238, 130, 238 ], [ 255, 0, 255 ], [ 218, 112, 214 ], [ 199, 21, 133 ], [ 219, 112, 147 ], - [ 255, 20, 147 ], [ 255, 105, 180 ], [ 255, 182, 193 ], [ 255, 192, 203 ], [ 250, 235, 215 ], - [ 245, 245, 220 ], [ 255, 228, 196 ], [ 255, 235, 205 ], [ 245, 222, 179 ], [ 255, 248, 220 ], - [ 255, 250, 205 ], [ 250, 250, 210 ], [ 255, 255, 224 ], [ 139, 69, 19 ], [ 160, 82, 45 ], - [ 210, 105, 30 ], [ 205, 133, 63 ], [ 244, 164, 96 ], [ 222, 184, 135 ], [ 210, 180, 140 ], - [ 188, 143, 143 ], [ 255, 228, 181 ], [ 255, 222, 173 ], [ 255, 218, 185 ], [ 255, 228, 225 ], - [ 255, 240, 245 ], [ 250, 240, 230 ], [ 253, 245, 230 ], [ 255, 239, 213 ], [ 255, 245, 238 ], - [ 245, 255, 250 ], [ 112, 128, 144 ], [ 119, 136, 153 ], [ 176, 196, 222 ], [ 230, 230, 250 ], - [ 255, 250, 240 ], [ 240, 248, 255 ], [ 248, 248, 255 ], [ 240, 255, 240 ], [ 255, 255, 240 ], - [ 240, 255, 255 ], [ 255, 250, 250 ], [ 192, 192, 192 ], [ 220, 220, 220 ], [ 245, 245, 245 ], -] - -color_names = [ - 'white', 'red', 'green', 'blue', 'yellow', - 'black', 'maroon', 'dark_red', 'brown', 'firebrick', - 'crimson', 'tomato', 'coral', 'indian_red', 'light_coral', - 'dark_salmon', 'salmon', 'light_salmon', 'orange_red', 'dark_orange', - 'orange', 'gold', 'dark_golden_rod', 'golden_rod', 'pale_golden_rod', - 'dark_khaki', 'khaki', 'olive', 'yellow_green', 'dark_olive_green', - 'olive_drab', 'lawn_green', 'chartreuse', 'green_yellow', 'dark_green', - 'forest_green', 'lime', 'lime_green', 'light_green', 'pale_green', - 'dark_sea_green', 'medium_spring_green', 'spring_green', 'sea_green', 'medium_aqua_marine', - 'medium_sea_green', 'light_sea_green', 'dark_slate_gray', 'teal', 'dark_cyan', - 'aqua', 'cyan', 'light_cyan', 'dark_turquoise', 'turquoise', - 'medium_turquoise', 'pale_turquoise', 'aqua_marine', 'powder_blue', 'cadet_blue', - 'steel_blue', 'corn_flower_blue', 'deep_sky_blue', 'dodger_blue', 'light_blue', - 'sky_blue', 'light_sky_blue', 'midnight_blue', 'navy', 'dark_blue', - 'medium_blue', 'royal_blue', 'blue_violet', 'indigo', 'dark_slate_blue', - 'slate_blue', 'medium_slate_blue', 'medium_purple', 'dark_magenta', 'dark_violet', - 'dark_orchid', 'medium_orchid', 'purple', 'thistle', 'plum', - 'violet', 'magenta', 'orchid', 'medium_violet_red', 'pale_violet_red', - 'deep_pink', 'hot_pink', 'light_pink', 'pink', 'antique_white', - 'beige', 'bisque', 'blanched_almond', 'wheat', 'corn_silk', - 'lemon_chiffon', 'light_golden_rod_yellow', 'light_yellow', 'saddle_brown', 'sienna', - 'chocolate', 'peru', 'sandy_brown', 'burly_wood', 'tan', - 'rosy_brown', 'moccasin', 'navajo_white', 'peach_puff', 'misty_rose', - 'lavender_blush', 'linen', 'old_lace', 'papaya_whip', 'sea_shell', - 'mint_cream', 'slate_gray', 'light_slate_gray', 'light_steel_blue', 'lavender', - 'floral_white', 'alice_blue', 'ghost_white', 'honeydew', 'ivory', - 'azure', 'snow', 'silver', 'gainsboro', 'white_smoke', -] - -color_id = dict( [ (n, k) for k, n in enumerate(color_names) ] ) -color_tokens = dict( [ (n, c) for n, c in zip(color_names, colors) ] ) +color_tokens = { + "white": [255, 255, 255], + "red": [255, 0, 0], + "green": [0, 128, 0], + "blue": [0, 0, 255], + "yellow": [255, 255, 0], + "black": [0, 0, 0], + "maroon": [128, 0, 0], + "dark_red": [139, 0, 0], + "brown": [165, 42, 42], + "firebrick": [178, 34, 34], + "crimson": [220, 20, 60], + "tomato": [255, 99, 71], + "coral": [255, 127, 80], + "indian_red": [205, 92, 92], + "light_coral": [240, 128, 128], + "dark_salmon": [233, 150, 122], + "salmon": [250, 128, 114], + "light_salmon": [255, 160, 122], + "orange_red": [255, 69, 0], + "dark_orange": [255, 140, 0], + "orange": [255, 165, 0], + "gold": [255, 215, 0], + "dark_golden_rod": [184, 134, 11], + "golden_rod": [218, 165, 32], + "pale_golden_rod": [238, 232, 170], + "dark_khaki": [189, 183, 107], + "khaki": [240, 230, 140], + "olive": [128, 128, 0], + "yellow_green": [154, 205, 50], + "dark_olive_green": [85, 107, 47], + "olive_drab": [107, 142, 35], + "lawn_green": [124, 252, 0], + "chartreuse": [127, 255, 0], + "green_yellow": [173, 255, 47], + "dark_green": [0, 100, 0], + "forest_green": [34, 139, 34], + "lime": [0, 255, 0], + "lime_green": [50, 205, 50], + "light_green": [144, 238, 144], + "pale_green": [152, 251, 152], + "dark_sea_green": [143, 188, 143], + "medium_spring_green": [0, 250, 154], + "spring_green": [0, 255, 127], + "sea_green": [46, 139, 87], + "medium_aqua_marine": [102, 205, 170], + "medium_sea_green": [60, 179, 113], + "light_sea_green": [32, 178, 170], + "dark_slate_gray": [47, 79, 79], + "teal": [0, 128, 128], + "dark_cyan": [0, 139, 139], + "aqua": [0, 255, 255], + "cyan": [0, 255, 255], + "light_cyan": [224, 255, 255], + "dark_turquoise": [0, 206, 209], + "turquoise": [64, 224, 208], + "medium_turquoise": [72, 209, 204], + "pale_turquoise": [175, 238, 238], + "aqua_marine": [127, 255, 212], + "powder_blue": [176, 224, 230], + "cadet_blue": [95, 158, 160], + "steel_blue": [70, 130, 180], + "corn_flower_blue": [100, 149, 237], + "deep_sky_blue": [0, 191, 255], + "dodger_blue": [30, 144, 255], + "light_blue": [173, 216, 230], + "sky_blue": [135, 206, 235], + "light_sky_blue": [135, 206, 250], + "midnight_blue": [25, 25, 112], + "navy": [0, 0, 128], + "dark_blue": [0, 0, 139], + "medium_blue": [0, 0, 205], + "royal_blue": [65, 105, 225], + "blue_violet": [138, 43, 226], + "indigo": [75, 0, 130], + "dark_slate_blue": [72, 61, 139], + "slate_blue": [106, 90, 205], + "medium_slate_blue": [123, 104, 238], + "medium_purple": [147, 112, 219], + "dark_magenta": [139, 0, 139], + "dark_violet": [148, 0, 211], + "dark_orchid": [153, 50, 204], + "medium_orchid": [186, 85, 211], + "purple": [128, 0, 128], + "thistle": [216, 191, 216], + "plum": [221, 160, 221], + "violet": [238, 130, 238], + "magenta": [255, 0, 255], + "orchid": [218, 112, 214], + "medium_violet_red": [199, 21, 133], + "pale_violet_red": [219, 112, 147], + "deep_pink": [255, 20, 147], + "hot_pink": [255, 105, 180], + "light_pink": [255, 182, 193], + "pink": [255, 192, 203], + "antique_white": [250, 235, 215], + "beige": [245, 245, 220], + "bisque": [255, 228, 196], + "blanched_almond": [255, 235, 205], + "wheat": [245, 222, 179], + "corn_silk": [255, 248, 220], + "lemon_chiffon": [255, 250, 205], + "light_golden_rod_yellow": [250, 250, 210], + "light_yellow": [255, 255, 224], + "saddle_brown": [139, 69, 19], + "sienna": [160, 82, 45], + "chocolate": [210, 105, 30], + "peru": [205, 133, 63], + "sandy_brown": [244, 164, 96], + "burly_wood": [222, 184, 135], + "tan": [210, 180, 140], + "rosy_brown": [188, 143, 143], + "moccasin": [255, 228, 181], + "navajo_white": [255, 222, 173], + "peach_puff": [255, 218, 185], + "misty_rose": [255, 228, 225], + "lavender_blush": [255, 240, 245], + "linen": [250, 240, 230], + "old_lace": [253, 245, 230], + "papaya_whip": [255, 239, 213], + "sea_shell": [255, 245, 238], + "mint_cream": [245, 255, 250], + "slate_gray": [112, 128, 144], + "light_slate_gray": [119, 136, 153], + "light_steel_blue": [176, 196, 222], + "lavender": [230, 230, 250], + "floral_white": [255, 250, 240], + "alice_blue": [240, 248, 255], + "ghost_white": [248, 248, 255], + "honeydew": [240, 255, 240], + "ivory": [255, 255, 240], + "azure": [240, 255, 255], + "snow": [255, 250, 250], + "silver": [192, 192, 192], + "gainsboro": [220, 220, 220], + "white_smoke": [245, 245, 245], +} + +color_id = dict([(n, k) for k, n in enumerate(color_tokens.keys())]) +color_names = dict([(k, n) for k, n in enumerate(color_tokens.keys())]) ###################################################################### -def all_properties(height, width, nb_squares, square_i, square_j, square_c): - s = [ ] - - for r, c in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]: - s += [ f'there is {c}' ] - if square_i[r] >= height - height//3: s += [ f'{c} bottom' ] - if square_i[r] < height//3: s += [ f'{c} top' ] - if square_j[r] >= width - width//3: s += [ f'{c} right' ] - if square_j[r] < width//3: s += [ f'{c} left' ] - - for t, d in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]: - if square_i[r] > square_i[t]: s += [ f'{c} below {d}' ] - if square_i[r] < square_i[t]: s += [ f'{c} above {d}' ] - if square_j[r] > square_j[t]: s += [ f'{c} right of {d}' ] - if square_j[r] < square_j[t]: s += [ f'{c} left of {d}' ] +def all_properties(height, width, nb_squares, square_i, square_j, square_c): + s = [] + + for r, c in [(k, color_names[square_c[k].item()]) for k in range(nb_squares)]: + s += [f"there is {c}"] + + if square_i[r] >= height - height // 3: + s += [f"{c} bottom"] + if square_i[r] < height // 3: + s += [f"{c} top"] + if square_j[r] >= width - width // 3: + s += [f"{c} right"] + if square_j[r] < width // 3: + s += [f"{c} left"] + + for t, d in [(k, color_names[square_c[k].item()]) for k in range(nb_squares)]: + if square_i[r] > square_i[t]: + s += [f"{c} below {d}"] + if square_i[r] < square_i[t]: + s += [f"{c} above {d}"] + if square_j[r] > square_j[t]: + s += [f"{c} right of {d}"] + if square_j[r] < square_j[t]: + s += [f"{c} left of {d}"] return s + ###################################################################### -def generate(nb, height, width, - max_nb_squares = 5, max_nb_properties = 10, - nb_colors = 5, - pruning_criterion = None): + +def generate( + nb, + height, + width, + max_nb_squares=5, + max_nb_properties=10, + nb_colors=5, + pruning_criterion=None, +): assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1 - descr = [ ] + descr = [] for n in range(nb): @@ -108,70 +202,77 @@ def generate(nb, height, width, square_position = torch.randperm(height * width)[:nb_squares] # color 0 is white and reserved for the background square_c = torch.randperm(nb_colors)[:nb_squares] + 1 - square_i = square_position.div(width, rounding_mode = 'floor') + square_i = square_position.div(width, rounding_mode="floor") square_j = square_position % width - img = [ 0 ] * height * width - for k in range(nb_squares): img[square_position[k]] = square_c[k] + img = torch.zeros(height * width, dtype=torch.int64) + for k in range(nb_squares): + img[square_position[k]] = square_c[k] # generates all the true properties s = all_properties(height, width, nb_squares, square_i, square_j, square_c) if pruning_criterion is not None: - s = list(filter(pruning_criterion,s)) + s = list(filter(pruning_criterion, s)) # pick at most max_nb_properties at random nb_properties = torch.randint(max_nb_properties, (1,)) + 1 - s = ' '.join([ s[k] for k in torch.randperm(len(s))[:nb_properties] ] ) - s += ' ' + ' '.join([ f'{color_names[n]}' for n in img ]) + s = " ".join([s[k] for k in torch.randperm(len(s))[:nb_properties]]) + s += " " + " ".join([f"{color_names[n.item()]}" for n in img]) - descr += [ s ] + descr += [s] return descr + ###################################################################### + def descr2img(descr, height, width): if type(descr) == list: - return torch.cat([ descr2img(d, height, width) for d in descr ], 0) + return torch.cat([descr2img(d, height, width) for d in descr], 0) def token2color(t): try: return color_tokens[t] except KeyError: - return [ 128, 128, 128 ] + return [128, 128, 128] - d = descr.split('', 1) - d = d[-1] if len(d) > 1 else '' - d = d.strip().split(' ')[:height * width] - d = d + [ '' ] * (height * width - len(d)) - d = [ token2color(t) for t in d ] + d = descr.split("", 1) + d = d[-1] if len(d) > 1 else "" + d = d.strip().split(" ")[: height * width] + d = d + [""] * (height * width - len(d)) + d = [token2color(t) for t in d] img = torch.tensor(d).permute(1, 0) img = img.reshape(1, 3, height, width) return img + ###################################################################### + def descr2properties(descr, height, width): if type(descr) == list: - return [ descr2properties(d, height, width) for d in descr ] + return [descr2properties(d, height, width) for d in descr] - d = descr.split('', 1) - d = d[-1] if len(d) > 1 else '' - d = d.strip().split(' ')[:height * width] + d = descr.split("", 1) + d = d[-1] if len(d) > 1 else "" + d = d.strip().split(" ")[: height * width] seen = {} - if len(d) != height * width: return [] + if len(d) != height * width: + return [] for k, x in enumerate(d): if x != color_names[0]: if x in color_tokens: - if x in seen: return [] + if x in seen: + return [] else: return [] seen[x] = (color_id[x], k // width, k % width) @@ -190,16 +291,19 @@ def descr2properties(descr, height, width): return s + ###################################################################### + def nb_properties(descr, height, width): if type(descr) == list: - return [ nb_properties(d, height, width) for d in descr ] + return [nb_properties(d, height, width) for d in descr] - d = descr.split('', 1) - if len(d) == 0: return 0 - d = d[0].strip().split('') - d = [ x.strip() for x in d ] + d = descr.split("", 1) + if len(d) == 0: + return 0 + d = d[0].strip().split("") + d = [x.strip() for x in d] requested_properties = set(d) all_properties = set(descr2properties(descr, height, width)) @@ -207,30 +311,36 @@ def nb_properties(descr, height, width): return (len(requested_properties), len(all_properties), len(missing_properties)) + ###################################################################### -if __name__ == '__main__': +if __name__ == "__main__": descr = generate( - nb = 5, height = 12, width = 16, - pruning_criterion = lambda s: not ('green' in s and ('right' in s or 'left' in s)) + nb=5, + height=12, + width=16, + pruning_criterion=lambda s: not ( + "green" in s and ("right" in s or "left" in s) + ), ) - print(descr2properties(descr, height = 12, width = 16)) - print(nb_properties(descr, height = 12, width = 16)) + print(descr2properties(descr, height=12, width=16)) + print(nb_properties(descr, height=12, width=16)) - with open('picoclvr_example.txt', 'w') as f: + with open("picoclvr_example.txt", "w") as f: for d in descr: - f.write(f'{d}\n\n') + f.write(f"{d}\n\n") - img = descr2img(descr, height = 12, width = 16) - torchvision.utils.save_image(img / 255., - 'picoclvr_example.png', nrow = 16, pad_value = 0.8) + img = descr2img(descr, height=12, width=16) + torchvision.utils.save_image( + img / 255.0, "picoclvr_example.png", nrow=16, pad_value=0.8 + ) import time start_time = time.perf_counter() - descr = generate(nb = 1000, height = 12, width = 16) + descr = generate(nb=1000, height=12, width=16) end_time = time.perf_counter() - print(f'{len(descr) / (end_time - start_time):.02f} samples per second') + print(f"{len(descr) / (end_time - start_time):.02f} samples per second") ###################################################################### -- 2.39.5