X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=aa1b51799a159b2f176154d63cce5893eabd38ad;hb=62533ba50393866c15b322074cad836684dd69e7;hp=83227bb43a24897b506f0b82035a69dc33f7acfe;hpb=38c69cc69cffd1a54b92bfe993f52aa649afb7d4;p=mygpt.git

diff --git a/main.py b/main.py
index 83227bb..aa1b517 100755
--- a/main.py
+++ b/main.py
@@ -42,7 +42,10 @@ parser.add_argument('--optim',
                     type = str, default = 'adam')
 
 parser.add_argument('--learning_rate',
-                    type = float, default = 1e-4)
+                    type = float, default = 1e-3)
+
+parser.add_argument('--learning_rate_end',
+                    type = float, default = 1e-6)
 
 parser.add_argument('--dim_model',
                     type = int, default = 512)
@@ -62,8 +65,8 @@ parser.add_argument('--nb_blocks',
 parser.add_argument('--dropout',
                     type = float, default = 0.1)
 
-parser.add_argument('--synthesis_sampling',
-                    action='store_true', default = True)
+parser.add_argument('--deterministic_synthesis',
+                    action='store_true', default = False)
 
 parser.add_argument('--no_checkpoint',
                     action='store_true', default = False)
@@ -129,11 +132,11 @@ def autoregression(
     for s in range(first, input.size(1)):
         output = model(input)
         logits = output[:, s]
-        if args.synthesis_sampling:
+        if args.deterministic_synthesis:
+            t_next = logits.argmax(1)
+        else:
             dist = torch.distributions.categorical.Categorical(logits = logits)
             t_next = dist.sample()
-        else:
-            t_next = logits.argmax(1)
         input[:, s] = t_next
 
     return results
@@ -188,6 +191,7 @@ class TaskPicoCLVR(Task):
         self.device = device
         nb = args.data_size if args.data_size > 0 else 250000
 
+        log_string(f'generating {nb} samples (can take some time)')
         self.train_descr = generate_descr((nb * 4) // 5)
         self.test_descr = generate_descr((nb * 1) // 5)
 
@@ -212,17 +216,11 @@ class TaskPicoCLVR(Task):
     def vocabulary_size(self):
         return len(self.token2id)
 
-    def produce_results(self, n_epoch, model):
+    def test_model(self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False):
         nb_tokens_to_generate = self.height * self.width + 3
         result_descr = [ ]
-        nb_per_primer = 8
 
-        for primer_descr in [
-            'red above green green top blue right of red ',
-            'there is red there is yellow there is blue ',
-            'red below yellow yellow below green green below blue red right yellow left green right blue left ',
-            'green bottom yellow bottom green left of blue yellow right of blue blue top ',
-        ]:
+        for primer_descr in primers_descr:
 
             results = autoregression(
                 model,
@@ -245,18 +243,57 @@ class TaskPicoCLVR(Task):
 
         log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
 
-        img = [
-            picoclvr.descr2img(d, height = self.height, width = self.width)
-            for d in result_descr
+        np=torch.tensor(np)
+        count=torch.empty(np[:,0].max()+1,np[:,2].max()+1,dtype=torch.int64)
+        for i in range(count.size(0)):
+            for j in range(count.size(1)):
+                count[i,j]=((np[:,0]==i).long()*(np[:,2]==j).long()).sum()
+
+        if generate_images:
+            img = [
+                picoclvr.descr2img(d, height = self.height, width = self.width)
+                for d in result_descr
+            ]
+
+            img = torch.cat(img, 0)
+            image_name = f'result_picoclvr_{n_epoch:04d}.png'
+            torchvision.utils.save_image(
+                img / 255.,
+                image_name, nrow = nb_per_primer, pad_value = 0.8
+            )
+            log_string(f'wrote {image_name}')
+
+        return count
+
+    def produce_results(self, n_epoch, model):
+        primers_descr = [
+            'red above green green top blue right of red ',
+            'there is red there is yellow there is blue ',
+            'red below yellow yellow below green green below blue red right yellow left green right blue left ',
+            'green bottom yellow bottom green left of blue yellow right of blue blue top ',
         ]
 
-        img = torch.cat(img, 0)
-        image_name = f'result_picoclvr_{n_epoch:04d}.png'
-        torchvision.utils.save_image(
-            img / 255.,
-            image_name, nrow = nb_per_primer, pad_value = 0.8
+        self.test_model(
+            n_epoch, model,
+            primers_descr,
+            nb_per_primer=8, generate_images=True
         )
-        log_string(f'wrote {image_name}')
+
+        # FAR TOO SLOW!!!
+
+        # test_primers_descr=[ s.split('<img>')[0] for s in self.test_descr ]
+
+        # count=self.test_model(
+        #     n_epoch, model,
+        #     test_primers_descr,
+        #     nb_per_primer=1, generate_images=False
+        # )
+
+        # with open(f'perf_{n_epoch:04d}.txt', 'w') as f:
+        #     for i in range(count.size(0)):
+        #         for j in range(count.size(1)):
+        #             f.write(f'{count[i,j]}')
+        #             f.write(" " if j < count.size(1) - 1 else "\n")
@@ -429,17 +466,6 @@ log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
 
 ######################################################################
 
-if args.optim == 'sgd':
-    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adam':
-    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adamw':
-    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
-else:
-    raise ValueError(f'Unknown optimizer {args.optim}.')
-
-######################################################################
-
 nb_epochs_finished = 0
 
 if args.no_checkpoint:
@@ -447,10 +473,12 @@ if args.no_checkpoint:
 else:
 
     try:
-        checkpoint = torch.load(args.checkpoint_name, map_location = device)
+        checkpoint = torch.load(args.checkpoint_name)
         nb_epochs_finished = checkpoint['nb_epochs_finished']
         model.load_state_dict(checkpoint['model_state'])
-        optimizer.load_state_dict(checkpoint['optimizer_state'])
+        torch.set_rng_state(checkpoint['rng_state'])
+        if torch.cuda.is_available():
+            torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])
         log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
 
     except FileNotFoundError:
@@ -471,7 +499,24 @@ token_probas = token_count / token_count.sum()
 entropy = -torch.xlogy(token_probas, token_probas).sum()
 train_set_perplexity = math.exp(entropy)
 
-for k in range(nb_epochs_finished, nb_epochs):
+for n_epoch in range(nb_epochs_finished, nb_epochs):
+
+    if args.learning_rate_end < 0:
+        lr = args.learning_rate
+    else:
+        u = n_epoch / (nb_epochs - 1)
+        lr = math.exp((1 - u) * math.log(args.learning_rate) +
+                      u * math.log(args.learning_rate_end))
+    log_string(f'learning_rate {lr}')
+
+    if args.optim == 'sgd':
+        optimizer = torch.optim.SGD(model.parameters(), lr = lr)
+    elif args.optim == 'adam':
+        optimizer = torch.optim.Adam(model.parameters(), lr = lr)
+    elif args.optim == 'adamw':
+        optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
+    else:
+        raise ValueError(f'Unknown optimizer {args.optim}.')
 
     model.train()
 
@@ -504,16 +549,19 @@
     train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
     test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
 
-    log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
+    log_string(f'perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
 
-    task.produce_results(k, model)
+    task.produce_results(n_epoch, model)
 
     checkpoint = {
-        'nb_epochs_finished': k + 1,
+        'nb_epochs_finished': n_epoch + 1,
         'model_state': model.state_dict(),
-        'optimizer_state': optimizer.state_dict()
+        'rng_state': torch.get_rng_state(),
     }
 
+    if torch.cuda.is_available():
+        checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state()
+
     torch.save(checkpoint, args.checkpoint_name)
 
 ######################################################################
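
######################################################################

The hunks above invert the sampling flag: --synthesis_sampling
(default True) becomes --deterministic_synthesis (default False), so
stochastic sampling stays the default and the new flag opts into
greedy decoding. A minimal sketch of the equivalent decision, outside
the diff; the helper name pick_next_token is illustrative, not from
main.py:

import torch

def pick_next_token(logits, deterministic_synthesis = False):
    # logits: tensor of shape (batch, vocab_size) holding unnormalized
    # scores for the next token
    if deterministic_synthesis:
        # greedy decoding: always take the most likely token
        return logits.argmax(1)
    else:
        # stochastic synthesis: sample in proportion to softmax(logits)
        dist = torch.distributions.categorical.Categorical(logits = logits)
        return dist.sample()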
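######################################################################

A sketch of what test_model's count tensor holds, assuming (as the
surrounding log_string suggests) that column 0 of np is the number of
requested properties of a sample and column 2 the number of missing
ones: count[i, j] is the number of generated samples with i requested
properties of which j are missing. The helper name and the
accumulation style are illustrative; the diff uses an explicit double
loop over (i, j):

import torch

def property_histogram(nb_props):
    # nb_props: integer tensor of shape (N, 3); columns 0 and 2 are
    # the per-sample requested and missing property counts
    count = torch.zeros(
        nb_props[:, 0].max() + 1,
        nb_props[:, 2].max() + 1,
        dtype = torch.int64,
    )
    for i, j in zip(nb_props[:, 0].tolist(), nb_props[:, 2].tolist()):
        count[i, j] += 1  # same result as the diff's double loop
    return count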
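######################################################################

The new --learning_rate / --learning_rate_end pair makes the training
loop interpolate the learning rate geometrically (linearly in log
space) from the first epoch to the last, and a negative
--learning_rate_end falls back to a constant rate. A standalone
sketch of the same arithmetic, assuming nb_epochs > 1; the function
name scheduled_lr is illustrative:

import math

def scheduled_lr(n_epoch, nb_epochs, learning_rate, learning_rate_end):
    if learning_rate_end < 0:
        return learning_rate  # schedule disabled, constant rate
    u = n_epoch / (nb_epochs - 1)  # 0 at the first epoch, 1 at the last
    return math.exp((1 - u) * math.log(learning_rate)
                    + u * math.log(learning_rate_end))

# With the new defaults over e.g. 10 epochs, this decays from 1e-3 at
# epoch 0 to 1e-6 at epoch 9, dividing the rate by 10 every 3 epochs.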
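######################################################################

Since the optimizer is now re-created at every epoch with the
scheduled rate, its state is no longer worth saving; the checkpoint
instead stores the torch RNG states so that a resumed run continues
with the same random draws. A sketch of the scheme, mirroring the
diff; the two function names are illustrative:

import torch

def save_checkpoint(file_name, model, nb_epochs_finished):
    checkpoint = {
        'nb_epochs_finished': nb_epochs_finished,
        'model_state': model.state_dict(),
        'rng_state': torch.get_rng_state(),
    }
    if torch.cuda.is_available():
        checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state()
    torch.save(checkpoint, file_name)

def load_checkpoint(file_name, model):
    # restores the model weights and RNG states, and returns the
    # number of epochs already finished
    checkpoint = torch.load(file_name)
    model.load_state_dict(checkpoint['model_state'])
    torch.set_rng_state(checkpoint['rng_state'])
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])
    return checkpoint['nb_epochs_finished']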