X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=d4a8cfb06c846e60f7f950d3cc8dadf2b9a7ce1b;hb=f3a734b6c522b2be0004a1b8bc2fe2eab2a90263;hp=427a83a05c6aef7f55413a7a1f225fb11efbb451;hpb=6260a3593ac09a9bbdd9c85b23d78a71fa028acd;p=mygpt.git

diff --git a/main.py b/main.py
index 427a83a..d4a8cfb 100755
--- a/main.py
+++ b/main.py
@@ -165,6 +165,12 @@ class TaskPicoCLVR(Task):
         id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ]
         return torch.tensor(id_descr, device = self.device)
 
+    def trim(self, x, token = '<nul>'):
+        n = self.token2id[token]
+        i = (1 - (F.pad(x, (1, 1), value = n) == n).min(0).values.long()).cumsum(0)
+        a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+        return x[:, a:b]
+
     def __init__(self, batch_size,
                  height, width, nb_colors = 5,
                  device = torch.device('cpu')):
@@ -182,6 +188,7 @@ class TaskPicoCLVR(Task):
         self.device = device
         nb = args.data_size if args.data_size > 0 else 250000
 
+        log_string(f'generating {nb} samples (can take some time)')
         self.train_descr = generate_descr((nb * 4) // 5)
         self.test_descr = generate_descr((nb * 1) // 5)
 
@@ -201,13 +208,13 @@ class TaskPicoCLVR(Task):
         assert split in { 'train', 'test' }
         input = self.train_input if split == 'train' else self.test_input
         for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
-            yield batch
+            yield self.trim(batch)
 
     def vocabulary_size(self):
         return len(self.token2id)
 
     def produce_results(self, n_epoch, model):
-        nb_tokens = self.height * self.width + 3
+        nb_tokens_to_generate = self.height * self.width + 3
         result_descr = [ ]
         nb_per_primer = 8
 
@@ -218,15 +225,26 @@ class TaskPicoCLVR(Task):
             'green bottom yellow bottom green left of blue yellow right of blue blue top ',
         ]:
 
-            for k in range(nb_per_primer):
-                results = autoregression(
-                    model, self.batch_size,
-                    nb_samples = 1, nb_tokens_to_generate = nb_tokens,
-                    primer = self.tensorize([ primer_descr ]),
-                    device = self.device
-                )
-                r = ' '.join([ self.id2token[t.item()] for t in results.flatten() ])
-                result_descr.append(r)
+            results = autoregression(
+                model,
+                self.batch_size,
+                nb_samples = nb_per_primer,
+                nb_tokens_to_generate = nb_tokens_to_generate,
+                primer = self.tensorize([ primer_descr ]).expand(nb_per_primer, -1),
+                device = self.device
+            )
+
+            l = [ ' '.join([ self.id2token[t.item()] for t in r ]) for r in results ]
+            result_descr += l
+
+        np = picoclvr.nb_properties(
+            result_descr,
+            height = self.height, width = self.width
+        )
+
+        nb_requested_properties, _, nb_missing_properties = zip(*np)
+
+        log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
 
         img = [
             picoclvr.descr2img(d, height = self.height, width = self.width)
@@ -241,15 +259,6 @@ class TaskPicoCLVR(Task):
         )
         log_string(f'wrote {image_name}')
 
-        np = picoclvr.nb_properties(
-            result_descr,
-            height = self.height, width = self.width
-        )
-
-        nb_requested_properties, _, nb_missing_properties = zip(*np)
-
-        log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
-
 ######################################################################
 
 class TaskWiki103(Task):
@@ -421,17 +430,6 @@ log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
 
 ######################################################################
 
-if args.optim == 'sgd':
-    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adam':
-    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adamw':
-    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
-else:
-    raise ValueError(f'Unknown optimizer {args.optim}.')
-
-######################################################################
-
 nb_epochs_finished = 0
 
 if args.no_checkpoint:
@@ -439,10 +437,12 @@ if args.no_checkpoint:
 
 else:
     try:
-        checkpoint = torch.load(args.checkpoint_name, map_location = device)
+        checkpoint = torch.load(args.checkpoint_name)
         nb_epochs_finished = checkpoint['nb_epochs_finished']
         model.load_state_dict(checkpoint['model_state'])
-        optimizer.load_state_dict(checkpoint['optimizer_state'])
+        torch.set_rng_state(checkpoint['rng_state'])
+        if torch.cuda.is_available():
+            torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])
         log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
 
     except FileNotFoundError:
@@ -462,9 +462,17 @@ for input in task.batches(split = 'train'):
 token_probas = token_count / token_count.sum()
 entropy = -torch.xlogy(token_probas, token_probas).sum()
 train_set_perplexity = math.exp(entropy)
-#log_string(f'train set perplexity {train_set_perplexity}')
 
-for k in range(nb_epochs_finished, nb_epochs):
+for n_epoch in range(nb_epochs_finished, nb_epochs):
+
+    if args.optim == 'sgd':
+        optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
+    elif args.optim == 'adam':
+        optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
+    elif args.optim == 'adamw':
+        optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
+    else:
+        raise ValueError(f'Unknown optimizer {args.optim}.')
 
     model.train()
 
@@ -497,16 +505,19 @@ for k in range(nb_epochs_finished, nb_epochs):
     train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
     test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
 
-    log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
+    log_string(f'perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
 
-    task.produce_results(k, model)
+    task.produce_results(n_epoch, model)
 
     checkpoint = {
-        'nb_epochs_finished': k + 1,
+        'nb_epochs_finished': n_epoch + 1,
         'model_state': model.state_dict(),
-        'optimizer_state': optimizer.state_dict()
+        'rng_state': torch.get_rng_state(),
     }
 
+    if torch.cuda.is_available():
+        checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state()
+
     torch.save(checkpoint, args.checkpoint_name)
 
######################################################################
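
Note on the new TaskPicoCLVR.trim (first hunk), whose tensor expression is dense: it drops the leading and trailing columns of a batch that hold the padding token in every row, and keeps everything in between, which is why batches() can yield tensors narrower than the fixed padded width of the data set. Below is a minimal standalone sketch of the same logic, with a toy tensor and the literal 0 standing in for self.token2id['<nul>']; the free function, the sample values, and the default n = 0 are illustrative only, not part of the commit.

import torch
import torch.nn.functional as F

def trim(x, n = 0):
    # Pad x with one extra column of n on each side, flag columns made
    # entirely of n with 0 and all other columns with 1, then count the
    # non-padding columns cumulatively along the sequence dimension.
    i = (1 - (F.pad(x, (1, 1), value = n) == n).min(0).values.long()).cumsum(0)
    # a: end of the all-padding prefix; b: start of the all-padding suffix
    # (in the padded coordinates, which offsets the slice by one column).
    a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
    return x[:, a:b]

x = torch.tensor([
    [0, 0, 3, 1, 0, 2, 0],
    [0, 5, 4, 0, 0, 1, 0],
])

print(trim(x))
# tensor([[0, 3, 1, 0, 2],
#         [5, 4, 0, 0, 1]])

Note that only columns that are padding across the whole batch are removed, and only from the two ends: the interior all-zero column above survives, so token positions stay aligned within each sequence while the batch sheds width it does not need.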