X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=6c1def76800aba41b7462d0ec977461562207d67;hb=0dbca4cef7405fb92689e5d2542f1d4761d658a3;hp=ace376da96dcdffb2fbceec73044052bbdc99aa9;hpb=98d2184fb3f202d0f513380ca00d080b64cf5e90;p=mygpt.git

diff --git a/main.py b/main.py
index ace376d..6c1def7 100755
--- a/main.py
+++ b/main.py
@@ -69,6 +69,9 @@ parser.add_argument('--dropout',
 parser.add_argument('--synthesis_sampling',
                     action='store_true', default = True)
 
+parser.add_argument('--no_checkpoint',
+                    action='store_true', default = False)
+
 parser.add_argument('--checkpoint_name',
                     type = str, default = 'checkpoint.pth')
 
@@ -130,36 +133,40 @@ class TaskPicoCLVR(Task):
                  height, width, many_colors = False,
                  device = torch.device('cpu')):
 
+        def generate_descr(nb):
+            descr = picoclvr.generate(
+                nb,
+                height = self.height, width = self.width,
+                many_colors = many_colors
+            )
+
+            descr = [ s.strip().split(' ') for s in descr ]
+            l = max([ len(s) for s in descr ])
+            descr = [ s + [ '<nul>' ] * (l - len(s)) for s in descr ]
+
+            return descr
+
         self.height = height
         self.width = width
         self.batch_size = batch_size
         self.device = device
         nb = args.data_size if args.data_size > 0 else 250000
 
-        descr = picoclvr.generate(
-            nb,
-            height = self.height, width = self.width,
-            many_colors = many_colors
-        )
-
-        # self.test_descr = descr[:nb // 5]
-        # self.train_descr = descr[nb // 5:]
-
-        descr = [ s.strip().split(' ') for s in descr ]
-        l = max([ len(s) for s in descr ])
-        descr = [ s + [ '<nul>' ] * (l - len(s)) for s in descr ]
+        self.train_descr = generate_descr((nb * 4) // 5)
+        self.test_descr = generate_descr((nb * 1) // 5)
 
+        # Build the tokenizer
         tokens = set()
-        for s in descr:
-            for t in s: tokens.add(t)
+        for d in [ self.train_descr, self.test_descr ]:
+            for s in d:
+                for t in s: tokens.add(t)
         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
 
-        t = [ [ self.token2id[u] for u in s ] for s in descr ]
-        data_input = torch.tensor(t, device = self.device)
-
-        self.test_input = data_input[:nb // 5]
-        self.train_input = data_input[nb // 5:]
+        t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
+        self.train_input = torch.tensor(t, device = self.device)
+        t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
+        self.test_input = torch.tensor(t, device = self.device)
 
     def batches(self, split = 'train'):
         assert split in { 'train', 'test' }
@@ -209,12 +216,20 @@ class TaskPicoCLVR(Task):
         img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
         img = torch.cat(img, 0)
-        file_name = f'result_picoclvr_{n_epoch:04d}.png'
-        torchvision.utils.save_image(img / 255.,
-                                     file_name, nrow = nb_per_primer, pad_value = 0.8)
-        log_string(f'wrote {file_name}')
+        image_name = f'result_picoclvr_{n_epoch:04d}.png'
+        torchvision.utils.save_image(
+            img / 255.,
+            image_name, nrow = nb_per_primer, pad_value = 0.8
+        )
+        log_string(f'wrote {image_name}')
+
+        nb_missing = sum( [
+            x[2] for x in picoclvr.nb_missing_properties(
+                descr,
+                height = self.height, width = self.width
+            )
+        ] )
 
-        nb_missing = sum( [ x[2] for x in picoclvr.nb_missing_properties(descr, height = self.height, width = self.width) ] )
         log_string(f'nb_missing {nb_missing / len(descr):.02f}')
 
 ######################################################################
@@ -418,19 +433,23 @@ else:
 
 nb_epochs_finished = 0
 
-try:
-    checkpoint = torch.load(args.checkpoint_name, map_location = device)
-    nb_epochs_finished = checkpoint['nb_epochs_finished']
-    model.load_state_dict(checkpoint['model_state'])
-    optimizer.load_state_dict(checkpoint['optimizer_state'])
-    print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
-
-except FileNotFoundError:
-    print('Starting from scratch.')
+if args.no_checkpoint:
+    log_string(f'Not trying to load checkpoint.')
 
-except:
-    print('Error when loading the checkpoint.')
-    exit(1)
+else:
+    try:
+        checkpoint = torch.load(args.checkpoint_name, map_location = device)
+        nb_epochs_finished = checkpoint['nb_epochs_finished']
+        model.load_state_dict(checkpoint['model_state'])
+        optimizer.load_state_dict(checkpoint['optimizer_state'])
+        log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
+
+    except FileNotFoundError:
+        log_string('Starting from scratch.')
+
+    except:
+        log_string('Error when loading the checkpoint.')
+        exit(1)
 
 ######################################################################
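The last hunk shows only the load side of the checkpoint round-trip; the save side lives elsewhere in main.py and is not part of this diff. A minimal sketch of what that save side must look like, assuming the dictionary keys implied by the load calls above ('nb_epochs_finished', 'model_state', 'optimizer_state'); the helper name save_checkpoint is hypothetical:

import torch

def save_checkpoint(model, optimizer, nb_epochs_finished, checkpoint_name):
    # The keys mirror the ones read back at startup: the epoch counter lets a
    # restarted run resume where it stopped, and the two state dicts restore
    # the model weights and the optimizer internals.
    checkpoint = {
        'nb_epochs_finished': nb_epochs_finished,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
    }
    torch.save(checkpoint, checkpoint_name)

With the new flag (main.py --no_checkpoint), the load is skipped entirely and training starts from scratch even when the checkpoint file exists; without it, a missing file falls through to the FileNotFoundError branch and any other loading failure aborts the run.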
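For reference, the padding and tokenization that generate_descr and the tokenizer build perform, as a self-contained sketch in which two toy descriptions stand in for the output of picoclvr.generate:

import torch

# Toy descriptions standing in for picoclvr.generate
descr = [ 'red above green', 'blue' ]

# Same processing as in generate_descr: split into tokens and right-pad
# every sequence with '<nul>' to the length of the longest one
descr = [ s.strip().split(' ') for s in descr ]
l = max([ len(s) for s in descr ])
descr = [ s + [ '<nul>' ] * (l - len(s)) for s in descr ]

# Same tokenizer build as in TaskPicoCLVR.__init__
tokens = set()
for s in descr:
    for t in s: tokens.add(t)
token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

# Equal-length rows make a rectangular tensor of token ids
x = torch.tensor([ [ token2id[u] for u in s ] for s in descr ])
print(x.size())  # torch.Size([2, 3])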