diff --git a/main.py b/main.py
index 1cd7342..82caebe 100755
--- a/main.py
+++ b/main.py
@@ -24,9 +24,6 @@ parser = argparse.ArgumentParser(description = 'My own GPT.')
 parser.add_argument('--log_filename',
                     type = str, default = 'train.log')
 
-parser.add_argument('--download',
-                    action='store_true', default = False)
-
 parser.add_argument('--seed',
                     type = int, default = 0)
 
@@ -173,6 +170,7 @@ class TaskPicoCLVR(Task):
             descr = [ s.strip().split(' ') for s in descr ]
             l = max([ len(s) for s in descr ])
+            #descr = [ [ '<nul>' ] * (l - len(s)) + s for s in descr ]
             descr = [ s + [ '<nul>' ] * (l - len(s)) for s in descr ]
 
             return descr
@@ -194,6 +192,7 @@ class TaskPicoCLVR(Task):
         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
 
+        # Tokenize the train and test sets
         t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
         self.train_input = torch.tensor(t, device = self.device)
         t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
@@ -446,7 +445,7 @@ else:
 nb_epochs_finished = 0
 
 if args.no_checkpoint:
-    log_string(f'Not trying to load checkpoint.')
+    log_string(f'not trying to load checkpoint.')
 
 else:
     try:
@@ -454,13 +453,13 @@ else:
         nb_epochs_finished = checkpoint['nb_epochs_finished']
         model.load_state_dict(checkpoint['model_state'])
         optimizer.load_state_dict(checkpoint['optimizer_state'])
-        log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
+        log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
 
     except FileNotFoundError:
-        log_string('Starting from scratch.')
+        log_string('starting from scratch.')
 
     except:
-        log_string('Error when loading the checkpoint.')
+        log_string('error when loading the checkpoint.')
         exit(1)
 
 ######################################################################
@@ -473,7 +472,7 @@ for input in task.batches(split = 'train'):
 token_probas = token_count / token_count.sum()
 h = -torch.xlogy(token_probas, token_probas).sum()
 train_set_perplexity = math.exp(h)
-log_string(f'Train set perplexity {train_set_perplexity}')
+log_string(f'train set perplexity {train_set_perplexity}')
 
 for k in range(nb_epochs_finished, nb_epochs):
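
For context, a minimal standalone sketch (not part of the patch) of the padding-and-tokenization pipeline the second and third hunks annotate: descriptions are split into tokens, right-padded with '<nul>' to a common length, and mapped to integer ids through token2id. The sample descriptions below are made up for illustration, and the class plumbing (self, device) is omitted.

import torch

# Hypothetical descriptions, standing in for generated picoclvr strings.
descr = [ 'red above green', 'blue' ]

# Split into tokens and right-pad with '<nul>' to a common length
# (the commented-out line added by the patch is the left-padding variant).
descr = [ s.strip().split(' ') for s in descr ]
l = max([ len(s) for s in descr ])
descr = [ s + [ '<nul>' ] * (l - len(s)) for s in descr ]

# Build the token <-> id mappings over the vocabulary.
tokens = set(t for s in descr for t in s)
token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

# Tokenize: one row of ids per padded description.
t = [ [ token2id[u] for u in s ] for s in descr ]
input = torch.tensor(t)
print(input.size())  # torch.Size([2, 3])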
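
Similarly, the last hunk sits in the startup code that logs the train-set perplexity: token counts over the training batches are normalized into a distribution, its entropy is taken with torch.xlogy (so zero counts contribute zero rather than NaN), and the perplexity is the exponential of that entropy. A self-contained numerical sketch with hypothetical counts:

import math
import torch

# Hypothetical token counts over a 4-token vocabulary.
token_count = torch.tensor([ 10.0, 20.0, 30.0, 40.0 ])

# Normalize counts into an empirical distribution.
token_probas = token_count / token_count.sum()

# Entropy h = -sum_t p(t) log p(t); torch.xlogy(p, p) computes p * log(p)
# and returns 0 where p = 0 instead of NaN.
h = -torch.xlogy(token_probas, token_probas).sum()

# Perplexity = exp(entropy): the effective number of equally likely tokens.
train_set_perplexity = math.exp(h)
print(train_set_perplexity)  # ~3.60 for this distribution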