X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=cd0e1ead475e79d1f7ac5aef504d9f8395f99001;hb=e36adb51bcc003b4189d92d6c5a31a1d86fe8837;hp=7ce80a31728f49acfb5cd026afda6b5feed7890d;hpb=086ec8f8d2ffeaac270fbedd991bb79122db7fdf;p=mygpt.git

diff --git a/main.py b/main.py
index 7ce80a3..cd0e1ea 100755
--- a/main.py
+++ b/main.py
@@ -24,9 +24,6 @@ parser = argparse.ArgumentParser(description = 'My own GPT.')
 parser.add_argument('--log_filename',
                     type = str, default = 'train.log')
 
-parser.add_argument('--download',
-                    action='store_true', default = False)
-
 parser.add_argument('--seed',
                     type = int, default = 0)
 
@@ -118,17 +115,18 @@ def autoregression(
     nb_samples, nb_tokens_to_generate, starting_input = None,
     device = torch.device('cpu')
 ):
-    first = 0
     results = torch.zeros(
         nb_samples, nb_tokens_to_generate,
         dtype = torch.int64, device = device
     )
 
-    if starting_input is not None:
+    if starting_input is None:
+        first = 0
+    else:
         first = starting_input.size(1)
         results = torch.cat((starting_input, results), 1)
 
-    for input in results.split(self.batch_size):
+    for input in results.split(args.batch_size):
         for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
             output = model(input)
             logits = output[:, s]
@@ -445,7 +443,7 @@ else:
 nb_epochs_finished = 0
 
 if args.no_checkpoint:
-    log_string(f'Not trying to load checkpoint.')
+    log_string(f'not trying to load checkpoint.')
 
 else:
     try:
@@ -453,13 +451,13 @@ else:
         nb_epochs_finished = checkpoint['nb_epochs_finished']
         model.load_state_dict(checkpoint['model_state'])
         optimizer.load_state_dict(checkpoint['optimizer_state'])
-        log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
+        log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
 
     except FileNotFoundError:
-        log_string('Starting from scratch.')
+        log_string('starting from scratch.')
 
     except:
-        log_string('Error when loading the checkpoint.')
+        log_string('error when loading the checkpoint.')
         exit(1)
 
 ######################################################################
@@ -472,7 +470,7 @@ for input in task.batches(split = 'train'):
 token_probas = token_count / token_count.sum()
 h = -torch.xlogy(token_probas, token_probas).sum()
 train_set_perplexity = math.exp(h)
-log_string(f'Train set perplexity {train_set_perplexity}')
+log_string(f'train set perplexity {train_set_perplexity}')
 
 for k in range(nb_epochs_finished, nb_epochs):
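
Note on the second hunk: autoregression is a module-level function, not a method, so the old results.split(self.batch_size) would raise a NameError at run time; the patch reads the batch size from the global args namespace instead, and makes the first offset explicit in both branches. Below is a minimal sketch of how the patched function reads, assuming the globals model and args from main.py; everything past logits = output[:, s] lies outside this hunk's context window, so the greedy token-selection step and the trailing return are illustrative placeholders, not necessarily what main.py does.

    import torch, tqdm

    def autoregression(
        nb_samples, nb_tokens_to_generate, starting_input = None,
        device = torch.device('cpu')
    ):
        # One row per sample, token positions filled in left to right.
        results = torch.zeros(
            nb_samples, nb_tokens_to_generate,
            dtype = torch.int64, device = device
        )

        if starting_input is None:
            first = 0
        else:
            # Keep the prompt tokens and generate only the positions after them.
            first = starting_input.size(1)
            results = torch.cat((starting_input, results), 1)

        # The patch's fix: a free function has no `self`, so the batch size
        # must come from the global `args` namespace.
        for input in results.split(args.batch_size):
            for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
                output = model(input)   # (batch, seq, vocab) logits
                logits = output[:, s]
                # Placeholder: greedy decoding. The real loop body is outside
                # this hunk and may sample from the distribution instead.
                input[:, s] = logits.argmax(1)

        # `split` returns views of `results`, so the in-place writes above are
        # visible here (return statement assumed, not shown in the hunk).
        return results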
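
The context of the last hunk estimates the train-set unigram entropy with torch.xlogy (which applies the 0 * log 0 = 0 convention) and logs exp(h) as the perplexity. As a quick standalone sanity check, with a hypothetical vocabulary size V, a uniform token distribution should give perplexity exactly V:

    import math, torch

    V = 256                                   # hypothetical vocabulary size
    token_probas = torch.full((V,), 1.0 / V)  # uniform unigram distribution

    # xlogy(p, p) = p * log(p) with xlogy(0, 0) = 0, so h is the entropy in nats.
    h = -torch.xlogy(token_probas, token_probas).sum()

    print(math.exp(h))  # ~256.0: uniform perplexity equals the vocabulary size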